anthonyrusso committed
Commit
f1e9197
1 Parent(s): b3ff8a5

upload audiocraft

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the complete change set.
Files changed (50)
  1. audiocraft/__init__.py +26 -0
  2. audiocraft/adversarial/__init__.py +22 -0
  3. audiocraft/adversarial/discriminators/__init__.py +10 -0
  4. audiocraft/adversarial/discriminators/base.py +34 -0
  5. audiocraft/adversarial/discriminators/mpd.py +106 -0
  6. audiocraft/adversarial/discriminators/msd.py +126 -0
  7. audiocraft/adversarial/discriminators/msstftd.py +134 -0
  8. audiocraft/adversarial/losses.py +228 -0
  9. audiocraft/data/__init__.py +10 -0
  10. audiocraft/data/audio.py +216 -0
  11. audiocraft/data/audio_dataset.py +587 -0
  12. audiocraft/data/audio_utils.py +176 -0
  13. audiocraft/data/info_audio_dataset.py +110 -0
  14. audiocraft/data/music_dataset.py +270 -0
  15. audiocraft/data/sound_dataset.py +330 -0
  16. audiocraft/data/zip.py +76 -0
  17. audiocraft/environment.py +176 -0
  18. audiocraft/grids/__init__.py +6 -0
  19. audiocraft/grids/_base_explorers.py +80 -0
  20. audiocraft/grids/audiogen/__init__.py +6 -0
  21. audiocraft/grids/audiogen/audiogen_base_16khz.py +23 -0
  22. audiocraft/grids/audiogen/audiogen_pretrained_16khz_eval.py +68 -0
  23. audiocraft/grids/compression/__init__.py +6 -0
  24. audiocraft/grids/compression/_explorers.py +55 -0
  25. audiocraft/grids/compression/debug.py +31 -0
  26. audiocraft/grids/compression/encodec_audiogen_16khz.py +29 -0
  27. audiocraft/grids/compression/encodec_base_24khz.py +28 -0
  28. audiocraft/grids/compression/encodec_musicgen_32khz.py +34 -0
  29. audiocraft/grids/diffusion/4_bands_base_32khz.py +27 -0
  30. audiocraft/grids/diffusion/__init__.py +6 -0
  31. audiocraft/grids/diffusion/_explorers.py +66 -0
  32. audiocraft/grids/musicgen/__init__.py +6 -0
  33. audiocraft/grids/musicgen/_explorers.py +93 -0
  34. audiocraft/grids/musicgen/musicgen_base_32khz.py +43 -0
  35. audiocraft/grids/musicgen/musicgen_base_cached_32khz.py +67 -0
  36. audiocraft/grids/musicgen/musicgen_clapemb_32khz.py +32 -0
  37. audiocraft/grids/musicgen/musicgen_melody_32khz.py +65 -0
  38. audiocraft/grids/musicgen/musicgen_pretrained_32khz_eval.py +99 -0
  39. audiocraft/losses/__init__.py +21 -0
  40. audiocraft/losses/balancer.py +136 -0
  41. audiocraft/losses/sisnr.py +92 -0
  42. audiocraft/losses/specloss.py +149 -0
  43. audiocraft/losses/stftloss.py +207 -0
  44. audiocraft/metrics/__init__.py +14 -0
  45. audiocraft/metrics/chroma_cosinesim.py +72 -0
  46. audiocraft/metrics/clap_consistency.py +84 -0
  47. audiocraft/metrics/fad.py +329 -0
  48. audiocraft/metrics/kld.py +220 -0
  49. audiocraft/metrics/rvm.py +110 -0
  50. audiocraft/metrics/visqol.py +216 -0
audiocraft/__init__.py ADDED
@@ -0,0 +1,26 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+ """
+ AudioCraft is a general framework for training audio generative models.
+ At the moment we provide the training code for:
+
+ - [MusicGen](https://arxiv.org/abs/2306.05284), a state-of-the-art
+   text-to-music and melody+text autoregressive generative model.
+   For the solver, see `audiocraft.solvers.musicgen.MusicGenSolver`, and for the model,
+   `audiocraft.models.musicgen.MusicGen`.
+ - [AudioGen](https://arxiv.org/abs/2209.15352), a state-of-the-art
+   text-to-general-audio generative model.
+ - [EnCodec](https://arxiv.org/abs/2210.13438), efficient and high fidelity
+   neural audio codec which provides an excellent tokenizer for autoregressive language models.
+   See `audiocraft.solvers.compression.CompressionSolver`, and `audiocraft.models.encodec.EncodecModel`.
+ - [MultiBandDiffusion](TODO), alternative diffusion-based decoder compatible with EnCodec that
+   improves the perceived quality and reduces the artifacts coming from adversarial decoders.
+ """
+
+ # flake8: noqa
+ from . import data, modules, models
+
+ __version__ = '1.0.0'
audiocraft/adversarial/__init__.py ADDED
@@ -0,0 +1,22 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+ """Adversarial losses and discriminator architectures."""
+
+ # flake8: noqa
+ from .discriminators import (
+     MultiPeriodDiscriminator,
+     MultiScaleDiscriminator,
+     MultiScaleSTFTDiscriminator
+ )
+ from .losses import (
+     AdversarialLoss,
+     AdvLossType,
+     get_adv_criterion,
+     get_fake_criterion,
+     get_real_criterion,
+     FeatLossType,
+     FeatureMatchingLoss
+ )
audiocraft/adversarial/discriminators/__init__.py ADDED
@@ -0,0 +1,10 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # flake8: noqa
+ from .mpd import MultiPeriodDiscriminator
+ from .msd import MultiScaleDiscriminator
+ from .msstftd import MultiScaleSTFTDiscriminator
audiocraft/adversarial/discriminators/base.py ADDED
@@ -0,0 +1,34 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from abc import ABC, abstractmethod
+ import typing as tp
+
+ import torch
+ import torch.nn as nn
+
+
+ FeatureMapType = tp.List[torch.Tensor]
+ LogitsType = torch.Tensor
+ MultiDiscriminatorOutputType = tp.Tuple[tp.List[LogitsType], tp.List[FeatureMapType]]
+
+
+ class MultiDiscriminator(ABC, nn.Module):
+     """Base implementation for discriminators composed of sub-discriminators acting at different scales.
+     """
+     def __init__(self):
+         super().__init__()
+
+     @abstractmethod
+     def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
+         ...
+
+     @property
+     @abstractmethod
+     def num_discriminators(self) -> int:
+         """Number of discriminators.
+         """
+         ...
audiocraft/adversarial/discriminators/mpd.py ADDED
@@ -0,0 +1,106 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import typing as tp
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from ...modules import NormConv2d
+ from .base import MultiDiscriminator, MultiDiscriminatorOutputType
+
+
+ def get_padding(kernel_size: int, dilation: int = 1) -> int:
+     return int((kernel_size * dilation - dilation) / 2)
+
+
+ class PeriodDiscriminator(nn.Module):
+     """Period sub-discriminator.
+
+     Args:
+         period (int): Period between samples of audio.
+         in_channels (int): Number of input channels.
+         out_channels (int): Number of output channels.
+         n_layers (int): Number of convolutional layers.
+         kernel_sizes (list of int): Kernel sizes for convolutions.
+         stride (int): Stride for convolutions.
+         filters (int): Initial number of filters in convolutions.
+         filters_scale (int): Multiplier of number of filters as we increase depth.
+         max_filters (int): Maximum number of filters.
+         norm (str): Normalization method.
+         activation (str): Activation function.
+         activation_params (dict): Parameters to provide to the activation function.
+     """
+     def __init__(self, period: int, in_channels: int = 1, out_channels: int = 1,
+                  n_layers: int = 5, kernel_sizes: tp.List[int] = [5, 3], stride: int = 3,
+                  filters: int = 8, filters_scale: int = 4, max_filters: int = 1024,
+                  norm: str = 'weight_norm', activation: str = 'LeakyReLU',
+                  activation_params: dict = {'negative_slope': 0.2}):
+         super().__init__()
+         self.period = period
+         self.n_layers = n_layers
+         self.activation = getattr(torch.nn, activation)(**activation_params)
+         self.convs = nn.ModuleList()
+         in_chs = in_channels
+         for i in range(self.n_layers):
+             out_chs = min(filters * (filters_scale ** (i + 1)), max_filters)
+             eff_stride = 1 if i == self.n_layers - 1 else stride
+             self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_sizes[0], 1), stride=(eff_stride, 1),
+                                          padding=((kernel_sizes[0] - 1) // 2, 0), norm=norm))
+             in_chs = out_chs
+         self.conv_post = NormConv2d(in_chs, out_channels, kernel_size=(kernel_sizes[1], 1), stride=1,
+                                     padding=((kernel_sizes[1] - 1) // 2, 0), norm=norm)
+
+     def forward(self, x: torch.Tensor):
+         fmap = []
+         # 1d to 2d
+         b, c, t = x.shape
+         if t % self.period != 0:  # pad first
+             n_pad = self.period - (t % self.period)
+             x = F.pad(x, (0, n_pad), 'reflect')
+             t = t + n_pad
+         x = x.view(b, c, t // self.period, self.period)
+
+         for conv in self.convs:
+             x = conv(x)
+             x = self.activation(x)
+             fmap.append(x)
+         x = self.conv_post(x)
+         fmap.append(x)
+         # x = torch.flatten(x, 1, -1)
+
+         return x, fmap
+
+
+ class MultiPeriodDiscriminator(MultiDiscriminator):
+     """Multi-Period (MPD) Discriminator.
+
+     Args:
+         in_channels (int): Number of input channels.
+         out_channels (int): Number of output channels.
+         periods (Sequence[int]): Periods between samples of audio for the sub-discriminators.
+         **kwargs: Additional args for `PeriodDiscriminator`
+     """
+     def __init__(self, in_channels: int = 1, out_channels: int = 1,
+                  periods: tp.Sequence[int] = [2, 3, 5, 7, 11], **kwargs):
+         super().__init__()
+         self.discriminators = nn.ModuleList([
+             PeriodDiscriminator(p, in_channels, out_channels, **kwargs) for p in periods
+         ])
+
+     @property
+     def num_discriminators(self):
+         return len(self.discriminators)
+
+     def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
+         logits = []
+         fmaps = []
+         for disc in self.discriminators:
+             logit, fmap = disc(x)
+             logits.append(logit)
+             fmaps.append(fmap)
+         return logits, fmaps
audiocraft/adversarial/discriminators/msd.py ADDED
@@ -0,0 +1,126 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import typing as tp
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+
+ from ...modules import NormConv1d
+ from .base import MultiDiscriminator, MultiDiscriminatorOutputType
+
+
+ class ScaleDiscriminator(nn.Module):
+     """Waveform sub-discriminator.
+
+     Args:
+         in_channels (int): Number of input channels.
+         out_channels (int): Number of output channels.
+         kernel_sizes (Sequence[int]): Kernel sizes for first and last convolutions.
+         filters (int): Number of initial filters for convolutions.
+         max_filters (int): Maximum number of filters.
+         downsample_scales (Sequence[int]): Scale for downsampling implemented as strided convolutions.
+         inner_kernel_sizes (Sequence[int] or None): Kernel sizes for inner convolutions.
+         groups (Sequence[int] or None): Groups for inner convolutions.
+         strides (Sequence[int] or None): Strides for inner convolutions.
+         paddings (Sequence[int] or None): Paddings for inner convolutions.
+         norm (str): Normalization method.
+         activation (str): Activation function.
+         activation_params (dict): Parameters to provide to the activation function.
+         pad (str): Padding for initial convolution.
+         pad_params (dict): Parameters to provide to the padding module.
+     """
+     def __init__(self, in_channels=1, out_channels=1, kernel_sizes: tp.Sequence[int] = [5, 3],
+                  filters: int = 16, max_filters: int = 1024, downsample_scales: tp.Sequence[int] = [4, 4, 4, 4],
+                  inner_kernel_sizes: tp.Optional[tp.Sequence[int]] = None, groups: tp.Optional[tp.Sequence[int]] = None,
+                  strides: tp.Optional[tp.Sequence[int]] = None, paddings: tp.Optional[tp.Sequence[int]] = None,
+                  norm: str = 'weight_norm', activation: str = 'LeakyReLU',
+                  activation_params: dict = {'negative_slope': 0.2}, pad: str = 'ReflectionPad1d',
+                  pad_params: dict = {}):
+         super().__init__()
+         assert len(kernel_sizes) == 2
+         assert kernel_sizes[0] % 2 == 1
+         assert kernel_sizes[1] % 2 == 1
+         assert (inner_kernel_sizes is None or len(inner_kernel_sizes) == len(downsample_scales))
+         assert (groups is None or len(groups) == len(downsample_scales))
+         assert (strides is None or len(strides) == len(downsample_scales))
+         assert (paddings is None or len(paddings) == len(downsample_scales))
+         self.activation = getattr(torch.nn, activation)(**activation_params)
+         self.convs = nn.ModuleList()
+         self.convs.append(
+             nn.Sequential(
+                 getattr(torch.nn, pad)((np.prod(kernel_sizes) - 1) // 2, **pad_params),
+                 NormConv1d(in_channels, filters, kernel_size=np.prod(kernel_sizes), stride=1, norm=norm)
+             )
+         )
+
+         in_chs = filters
+         for i, downsample_scale in enumerate(downsample_scales):
+             out_chs = min(in_chs * downsample_scale, max_filters)
+             default_kernel_size = downsample_scale * 10 + 1
+             default_stride = downsample_scale
+             default_padding = (default_kernel_size - 1) // 2
+             default_groups = in_chs // 4
+             self.convs.append(
+                 NormConv1d(in_chs, out_chs,
+                            kernel_size=inner_kernel_sizes[i] if inner_kernel_sizes else default_kernel_size,
+                            stride=strides[i] if strides else default_stride,
+                            groups=groups[i] if groups else default_groups,
+                            padding=paddings[i] if paddings else default_padding,
+                            norm=norm))
+             in_chs = out_chs
+
+         out_chs = min(in_chs * 2, max_filters)
+         self.convs.append(NormConv1d(in_chs, out_chs, kernel_size=kernel_sizes[0], stride=1,
+                                      padding=(kernel_sizes[0] - 1) // 2, norm=norm))
+         self.conv_post = NormConv1d(out_chs, out_channels, kernel_size=kernel_sizes[1], stride=1,
+                                     padding=(kernel_sizes[1] - 1) // 2, norm=norm)
+
+     def forward(self, x: torch.Tensor):
+         fmap = []
+         for layer in self.convs:
+             x = layer(x)
+             x = self.activation(x)
+             fmap.append(x)
+         x = self.conv_post(x)
+         fmap.append(x)
+         # x = torch.flatten(x, 1, -1)
+         return x, fmap
+
+
+ class MultiScaleDiscriminator(MultiDiscriminator):
+     """Multi-Scale (MSD) Discriminator,
+
+     Args:
+         in_channels (int): Number of input channels.
+         out_channels (int): Number of output channels.
+         downsample_factor (int): Downsampling factor between the different scales.
+         scale_norms (Sequence[str]): Normalization for each sub-discriminator.
+         **kwargs: Additional args for ScaleDiscriminator.
+     """
+     def __init__(self, in_channels: int = 1, out_channels: int = 1, downsample_factor: int = 2,
+                  scale_norms: tp.Sequence[str] = ['weight_norm', 'weight_norm', 'weight_norm'], **kwargs):
+         super().__init__()
+         self.discriminators = nn.ModuleList([
+             ScaleDiscriminator(in_channels, out_channels, norm=norm, **kwargs) for norm in scale_norms
+         ])
+         self.downsample = nn.AvgPool1d(downsample_factor * 2, downsample_factor, padding=downsample_factor)
+
+     @property
+     def num_discriminators(self):
+         return len(self.discriminators)
+
+     def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
+         logits = []
+         fmaps = []
+         for i, disc in enumerate(self.discriminators):
+             if i != 0:
+                 self.downsample(x)
+             logit, fmap = disc(x)
+             logits.append(logit)
+             fmaps.append(fmap)
+         return logits, fmaps
audiocraft/adversarial/discriminators/msstftd.py ADDED
@@ -0,0 +1,134 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import typing as tp
+
+ import torchaudio
+ import torch
+ from torch import nn
+ from einops import rearrange
+
+ from ...modules import NormConv2d
+ from .base import MultiDiscriminator, MultiDiscriminatorOutputType
+
+
+ def get_2d_padding(kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[int, int] = (1, 1)):
+     return (((kernel_size[0] - 1) * dilation[0]) // 2, ((kernel_size[1] - 1) * dilation[1]) // 2)
+
+
+ class DiscriminatorSTFT(nn.Module):
+     """STFT sub-discriminator.
+
+     Args:
+         filters (int): Number of filters in convolutions.
+         in_channels (int): Number of input channels.
+         out_channels (int): Number of output channels.
+         n_fft (int): Size of FFT for each scale.
+         hop_length (int): Length of hop between STFT windows for each scale.
+         kernel_size (tuple of int): Inner Conv2d kernel sizes.
+         stride (tuple of int): Inner Conv2d strides.
+         dilations (list of int): Inner Conv2d dilation on the time dimension.
+         win_length (int): Window size for each scale.
+         normalized (bool): Whether to normalize by magnitude after stft.
+         norm (str): Normalization method.
+         activation (str): Activation function.
+         activation_params (dict): Parameters to provide to the activation function.
+         growth (int): Growth factor for the filters.
+     """
+     def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1,
+                  n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024, max_filters: int = 1024,
+                  filters_scale: int = 1, kernel_size: tp.Tuple[int, int] = (3, 9), dilations: tp.List = [1, 2, 4],
+                  stride: tp.Tuple[int, int] = (1, 2), normalized: bool = True, norm: str = 'weight_norm',
+                  activation: str = 'LeakyReLU', activation_params: dict = {'negative_slope': 0.2}):
+         super().__init__()
+         assert len(kernel_size) == 2
+         assert len(stride) == 2
+         self.filters = filters
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.n_fft = n_fft
+         self.hop_length = hop_length
+         self.win_length = win_length
+         self.normalized = normalized
+         self.activation = getattr(torch.nn, activation)(**activation_params)
+         self.spec_transform = torchaudio.transforms.Spectrogram(
+             n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, window_fn=torch.hann_window,
+             normalized=self.normalized, center=False, pad_mode=None, power=None)
+         spec_channels = 2 * self.in_channels
+         self.convs = nn.ModuleList()
+         self.convs.append(
+             NormConv2d(spec_channels, self.filters, kernel_size=kernel_size, padding=get_2d_padding(kernel_size))
+         )
+         in_chs = min(filters_scale * self.filters, max_filters)
+         for i, dilation in enumerate(dilations):
+             out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters)
+             self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride,
+                                          dilation=(dilation, 1), padding=get_2d_padding(kernel_size, (dilation, 1)),
+                                          norm=norm))
+             in_chs = out_chs
+         out_chs = min((filters_scale ** (len(dilations) + 1)) * self.filters, max_filters)
+         self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_size[0], kernel_size[0]),
+                                      padding=get_2d_padding((kernel_size[0], kernel_size[0])),
+                                      norm=norm))
+         self.conv_post = NormConv2d(out_chs, self.out_channels,
+                                     kernel_size=(kernel_size[0], kernel_size[0]),
+                                     padding=get_2d_padding((kernel_size[0], kernel_size[0])),
+                                     norm=norm)
+
+     def forward(self, x: torch.Tensor):
+         fmap = []
+         z = self.spec_transform(x)  # [B, 2, Freq, Frames, 2]
+         z = torch.cat([z.real, z.imag], dim=1)
+         z = rearrange(z, 'b c w t -> b c t w')
+         for i, layer in enumerate(self.convs):
+             z = layer(z)
+             z = self.activation(z)
+             fmap.append(z)
+         z = self.conv_post(z)
+         return z, fmap
+
+
+ class MultiScaleSTFTDiscriminator(MultiDiscriminator):
+     """Multi-Scale STFT (MS-STFT) discriminator.
+
+     Args:
+         filters (int): Number of filters in convolutions.
+         in_channels (int): Number of input channels.
+         out_channels (int): Number of output channels.
+         sep_channels (bool): Separate channels to distinct samples for stereo support.
+         n_ffts (Sequence[int]): Size of FFT for each scale.
+         hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale.
+         win_lengths (Sequence[int]): Window size for each scale.
+         **kwargs: Additional args for STFTDiscriminator.
+     """
+     def __init__(self, filters: int, in_channels: int = 1, out_channels: int = 1, sep_channels: bool = False,
+                  n_ffts: tp.List[int] = [1024, 2048, 512], hop_lengths: tp.List[int] = [256, 512, 128],
+                  win_lengths: tp.List[int] = [1024, 2048, 512], **kwargs):
+         super().__init__()
+         assert len(n_ffts) == len(hop_lengths) == len(win_lengths)
+         self.sep_channels = sep_channels
+         self.discriminators = nn.ModuleList([
+             DiscriminatorSTFT(filters, in_channels=in_channels, out_channels=out_channels,
+                               n_fft=n_ffts[i], win_length=win_lengths[i], hop_length=hop_lengths[i], **kwargs)
+             for i in range(len(n_ffts))
+         ])
+
+     @property
+     def num_discriminators(self):
+         return len(self.discriminators)
+
+     def _separate_channels(self, x: torch.Tensor) -> torch.Tensor:
+         B, C, T = x.shape
+         return x.view(-1, 1, T)
+
+     def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
+         logits = []
+         fmaps = []
+         for disc in self.discriminators:
+             logit, fmap = disc(x)
+             logits.append(logit)
+             fmaps.append(fmap)
+         return logits, fmaps
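All three discriminator variants above implement the `MultiDiscriminator` interface and return a list of logits plus a list of feature maps, one entry per sub-discriminator. As a quick sanity check (not part of this commit; a minimal sketch assuming the constructor defaults shown above), the MS-STFT discriminator can be probed with random audio:

```python
# Hypothetical smoke test for MultiScaleSTFTDiscriminator; shapes are illustrative only.
import torch
from audiocraft.adversarial import MultiScaleSTFTDiscriminator

disc = MultiScaleSTFTDiscriminator(filters=32)
wav = torch.randn(4, 1, 24000)                   # [batch, channels, time]
logits, fmaps = disc(wav)
assert len(logits) == disc.num_discriminators    # one logit tensor per STFT scale
assert all(len(fmap) > 0 for fmap in fmaps)      # feature maps collected from the inner convs
```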
audiocraft/adversarial/losses.py ADDED
@@ -0,0 +1,228 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ Utility module to handle adversarial losses without requiring to mess up the main training loop.
+ """
+
+ import typing as tp
+
+ import flashy
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ ADVERSARIAL_LOSSES = ['mse', 'hinge', 'hinge2']
+
+
+ AdvLossType = tp.Union[nn.Module, tp.Callable[[torch.Tensor], torch.Tensor]]
+ FeatLossType = tp.Union[nn.Module, tp.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]]
+
+
+ class AdversarialLoss(nn.Module):
+     """Adversary training wrapper.
+
+     Args:
+         adversary (nn.Module): The adversary module will be used to estimate the logits given the fake and real samples.
+             We assume here the adversary output is ``Tuple[List[torch.Tensor], List[List[torch.Tensor]]]``
+             where the first item is a list of logits and the second item is a list of feature maps.
+         optimizer (torch.optim.Optimizer): Optimizer used for training the given module.
+         loss (AdvLossType): Loss function for generator training.
+         loss_real (AdvLossType): Loss function for adversarial training on logits from real samples.
+         loss_fake (AdvLossType): Loss function for adversarial training on logits from fake samples.
+         loss_feat (FeatLossType): Feature matching loss function for generator training.
+         normalize (bool): Whether to normalize by number of sub-discriminators.
+
+     Example of usage:
+         adv_loss = AdversarialLoss(adversaries, optimizer, loss, loss_real, loss_fake)
+         for real in loader:
+             noise = torch.randn(...)
+             fake = model(noise)
+             adv_loss.train_adv(fake, real)
+             loss, _ = adv_loss(fake, real)
+             loss.backward()
+     """
+     def __init__(self,
+                  adversary: nn.Module,
+                  optimizer: torch.optim.Optimizer,
+                  loss: AdvLossType,
+                  loss_real: AdvLossType,
+                  loss_fake: AdvLossType,
+                  loss_feat: tp.Optional[FeatLossType] = None,
+                  normalize: bool = True):
+         super().__init__()
+         self.adversary: nn.Module = adversary
+         flashy.distrib.broadcast_model(self.adversary)
+         self.optimizer = optimizer
+         self.loss = loss
+         self.loss_real = loss_real
+         self.loss_fake = loss_fake
+         self.loss_feat = loss_feat
+         self.normalize = normalize
+
+     def _save_to_state_dict(self, destination, prefix, keep_vars):
+         # Add the optimizer state dict inside our own.
+         super()._save_to_state_dict(destination, prefix, keep_vars)
+         destination[prefix + 'optimizer'] = self.optimizer.state_dict()
+         return destination
+
+     def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+         # Load optimizer state.
+         self.optimizer.load_state_dict(state_dict.pop(prefix + 'optimizer'))
+         super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
+
+     def get_adversary_pred(self, x):
+         """Run adversary model, validating expected output format."""
+         logits, fmaps = self.adversary(x)
+         assert isinstance(logits, list) and all([isinstance(t, torch.Tensor) for t in logits]), \
+             f'Expecting a list of tensors as logits but {type(logits)} found.'
+         assert isinstance(fmaps, list), f'Expecting a list of features maps but {type(fmaps)} found.'
+         for fmap in fmaps:
+             assert isinstance(fmap, list) and all([isinstance(f, torch.Tensor) for f in fmap]), \
+                 f'Expecting a list of tensors as feature maps but {type(fmap)} found.'
+         return logits, fmaps
+
+     def train_adv(self, fake: torch.Tensor, real: torch.Tensor) -> torch.Tensor:
+         """Train the adversary with the given fake and real example.
+
+         We assume the adversary output is the following format: Tuple[List[torch.Tensor], List[List[torch.Tensor]]].
+         The first item being the logits and second item being a list of feature maps for each sub-discriminator.
+
+         This will automatically synchronize gradients (with `flashy.distrib.eager_sync_model`)
+         and call the optimizer.
+         """
+         loss = torch.tensor(0., device=fake.device)
+         all_logits_fake_is_fake, _ = self.get_adversary_pred(fake.detach())
+         all_logits_real_is_fake, _ = self.get_adversary_pred(real.detach())
+         n_sub_adversaries = len(all_logits_fake_is_fake)
+         for logit_fake_is_fake, logit_real_is_fake in zip(all_logits_fake_is_fake, all_logits_real_is_fake):
+             loss += self.loss_fake(logit_fake_is_fake) + self.loss_real(logit_real_is_fake)
+
+         if self.normalize:
+             loss /= n_sub_adversaries
+
+         self.optimizer.zero_grad()
+         with flashy.distrib.eager_sync_model(self.adversary):
+             loss.backward()
+         self.optimizer.step()
+
+         return loss
+
+     def forward(self, fake: torch.Tensor, real: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+         """Return the loss for the generator, i.e. trying to fool the adversary,
+         and feature matching loss if provided.
+         """
+         adv = torch.tensor(0., device=fake.device)
+         feat = torch.tensor(0., device=fake.device)
+         with flashy.utils.readonly(self.adversary):
+             all_logits_fake_is_fake, all_fmap_fake = self.get_adversary_pred(fake)
+             all_logits_real_is_fake, all_fmap_real = self.get_adversary_pred(real)
+             n_sub_adversaries = len(all_logits_fake_is_fake)
+             for logit_fake_is_fake in all_logits_fake_is_fake:
+                 adv += self.loss(logit_fake_is_fake)
+             if self.loss_feat:
+                 for fmap_fake, fmap_real in zip(all_fmap_fake, all_fmap_real):
+                     feat += self.loss_feat(fmap_fake, fmap_real)
+
+         if self.normalize:
+             adv /= n_sub_adversaries
+             feat /= n_sub_adversaries
+
+         return adv, feat
+
+
+ def get_adv_criterion(loss_type: str) -> tp.Callable:
+     assert loss_type in ADVERSARIAL_LOSSES
+     if loss_type == 'mse':
+         return mse_loss
+     elif loss_type == 'hinge':
+         return hinge_loss
+     elif loss_type == 'hinge2':
+         return hinge2_loss
+     raise ValueError('Unsupported loss')
+
+
+ def get_fake_criterion(loss_type: str) -> tp.Callable:
+     assert loss_type in ADVERSARIAL_LOSSES
+     if loss_type == 'mse':
+         return mse_fake_loss
+     elif loss_type in ['hinge', 'hinge2']:
+         return hinge_fake_loss
+     raise ValueError('Unsupported loss')
+
+
+ def get_real_criterion(loss_type: str) -> tp.Callable:
+     assert loss_type in ADVERSARIAL_LOSSES
+     if loss_type == 'mse':
+         return mse_real_loss
+     elif loss_type in ['hinge', 'hinge2']:
+         return hinge_real_loss
+     raise ValueError('Unsupported loss')
+
+
+ def mse_real_loss(x: torch.Tensor) -> torch.Tensor:
+     return F.mse_loss(x, torch.tensor(1., device=x.device).expand_as(x))
+
+
+ def mse_fake_loss(x: torch.Tensor) -> torch.Tensor:
+     return F.mse_loss(x, torch.tensor(0., device=x.device).expand_as(x))
+
+
+ def hinge_real_loss(x: torch.Tensor) -> torch.Tensor:
+     return -torch.mean(torch.min(x - 1, torch.tensor(0., device=x.device).expand_as(x)))
+
+
+ def hinge_fake_loss(x: torch.Tensor) -> torch.Tensor:
+     return -torch.mean(torch.min(-x - 1, torch.tensor(0., device=x.device).expand_as(x)))
+
+
+ def mse_loss(x: torch.Tensor) -> torch.Tensor:
+     if x.numel() == 0:
+         return torch.tensor([0.0], device=x.device)
+     return F.mse_loss(x, torch.tensor(1., device=x.device).expand_as(x))
+
+
+ def hinge_loss(x: torch.Tensor) -> torch.Tensor:
+     if x.numel() == 0:
+         return torch.tensor([0.0], device=x.device)
+     return -x.mean()
+
+
+ def hinge2_loss(x: torch.Tensor) -> torch.Tensor:
+     if x.numel() == 0:
+         return torch.tensor([0.0])
+     return -torch.mean(torch.min(x - 1, torch.tensor(0., device=x.device).expand_as(x)))
+
+
+ class FeatureMatchingLoss(nn.Module):
+     """Feature matching loss for adversarial training.
+
+     Args:
+         loss (nn.Module): Loss to use for feature matching (default=torch.nn.L1).
+         normalize (bool): Whether to normalize the loss.
+             by number of feature maps.
+     """
+     def __init__(self, loss: nn.Module = torch.nn.L1Loss(), normalize: bool = True):
+         super().__init__()
+         self.loss = loss
+         self.normalize = normalize
+
+     def forward(self, fmap_fake: tp.List[torch.Tensor], fmap_real: tp.List[torch.Tensor]) -> torch.Tensor:
+         assert len(fmap_fake) == len(fmap_real) and len(fmap_fake) > 0
+         feat_loss = torch.tensor(0., device=fmap_fake[0].device)
+         feat_scale = torch.tensor(0., device=fmap_fake[0].device)
+         n_fmaps = 0
+         for (feat_fake, feat_real) in zip(fmap_fake, fmap_real):
+             assert feat_fake.shape == feat_real.shape
+             n_fmaps += 1
+             feat_loss += self.loss(feat_fake, feat_real)
+             feat_scale += torch.mean(torch.abs(feat_real))
+
+         if self.normalize:
+             feat_loss /= n_fmaps
+
+         return feat_loss
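The `AdversarialLoss` docstring above sketches the intended training loop. A slightly more concrete version (hypothetical, not part of this commit: the generator output, optimizer settings, and loss choices are placeholders, and a single-process run is assumed so the flashy distributed helpers act as no-ops) wires it to the MS-STFT discriminator defined earlier:

```python
# Hypothetical single-step sketch combining the pieces above.
import torch
from audiocraft.adversarial import (
    AdversarialLoss, MultiScaleSTFTDiscriminator,
    get_adv_criterion, get_fake_criterion, get_real_criterion,
)

adversary = MultiScaleSTFTDiscriminator(filters=32)
adv_loss = AdversarialLoss(
    adversary,
    torch.optim.Adam(adversary.parameters(), lr=3e-4),
    loss=get_adv_criterion('hinge'),
    loss_real=get_real_criterion('hinge'),
    loss_fake=get_fake_criterion('hinge'),
)
real = torch.randn(2, 1, 24000)
fake = real + 0.01 * torch.randn_like(real)   # stand-in for a generator output
d_loss = adv_loss.train_adv(fake, real)       # one discriminator update
g_adv, g_feat = adv_loss(fake, real)          # generator-side adversarial / feature losses
```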
audiocraft/data/__init__.py ADDED
@@ -0,0 +1,10 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+ """Audio loading and writing support. Datasets for raw audio
+ or also including some metadata."""
+
+ # flake8: noqa
+ from . import audio, audio_dataset, info_audio_dataset, music_dataset, sound_dataset
audiocraft/data/audio.py ADDED
@@ -0,0 +1,216 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ Audio IO methods are defined in this module (info, read, write),
+ We rely on av library for faster read when possible, otherwise on torchaudio.
+ """
+
+ from dataclasses import dataclass
+ from pathlib import Path
+ import logging
+ import typing as tp
+
+ import numpy as np
+ import soundfile
+ import torch
+ from torch.nn import functional as F
+ import torchaudio as ta
+
+ import av
+
+ from .audio_utils import f32_pcm, i16_pcm, normalize_audio
+
+
+ _av_initialized = False
+
+
+ def _init_av():
+     global _av_initialized
+     if _av_initialized:
+         return
+     logger = logging.getLogger('libav.mp3')
+     logger.setLevel(logging.ERROR)
+     _av_initialized = True
+
+
+ @dataclass(frozen=True)
+ class AudioFileInfo:
+     sample_rate: int
+     duration: float
+     channels: int
+
+
+ def _av_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
+     _init_av()
+     with av.open(str(filepath)) as af:
+         stream = af.streams.audio[0]
+         sample_rate = stream.codec_context.sample_rate
+         duration = float(stream.duration * stream.time_base)
+         channels = stream.channels
+         return AudioFileInfo(sample_rate, duration, channels)
+
+
+ def _soundfile_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
+     info = soundfile.info(filepath)
+     return AudioFileInfo(info.samplerate, info.duration, info.channels)
+
+
+ def audio_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
+     # torchaudio no longer returns useful duration informations for some formats like mp3s.
+     filepath = Path(filepath)
+     if filepath.suffix in ['.flac', '.ogg']:  # TODO: Validate .ogg can be safely read with av_info
+         # ffmpeg has some weird issue with flac.
+         return _soundfile_info(filepath)
+     else:
+         return _av_info(filepath)
+
+
+ def _av_read(filepath: tp.Union[str, Path], seek_time: float = 0, duration: float = -1.) -> tp.Tuple[torch.Tensor, int]:
+     """FFMPEG-based audio file reading using PyAV bindings.
+     Soundfile cannot read mp3 and av_read is more efficient than torchaudio.
+
+     Args:
+         filepath (str or Path): Path to audio file to read.
+         seek_time (float): Time at which to start reading in the file.
+         duration (float): Duration to read from the file. If set to -1, the whole file is read.
+     Returns:
+         tuple of torch.Tensor, int: Tuple containing audio data and sample rate
+     """
+     _init_av()
+     with av.open(str(filepath)) as af:
+         stream = af.streams.audio[0]
+         sr = stream.codec_context.sample_rate
+         num_frames = int(sr * duration) if duration >= 0 else -1
+         frame_offset = int(sr * seek_time)
+         # we need a small negative offset otherwise we get some edge artifact
+         # from the mp3 decoder.
+         af.seek(int(max(0, (seek_time - 0.1)) / stream.time_base), stream=stream)
+         frames = []
+         length = 0
+         for frame in af.decode(streams=stream.index):
+             current_offset = int(frame.rate * frame.pts * frame.time_base)
+             strip = max(0, frame_offset - current_offset)
+             buf = torch.from_numpy(frame.to_ndarray())
+             if buf.shape[0] != stream.channels:
+                 buf = buf.view(-1, stream.channels).t()
+             buf = buf[:, strip:]
+             frames.append(buf)
+             length += buf.shape[1]
+             if num_frames > 0 and length >= num_frames:
+                 break
+         assert frames
+         # If the above assert fails, it is likely because we seeked past the end of file point,
+         # in which case ffmpeg returns a single frame with only zeros, and a weird timestamp.
+         # This will need proper debugging, in due time.
+         wav = torch.cat(frames, dim=1)
+         assert wav.shape[0] == stream.channels
+         if num_frames > 0:
+             wav = wav[:, :num_frames]
+         return f32_pcm(wav), sr
+
+
+ def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
+                duration: float = -1., pad: bool = False) -> tp.Tuple[torch.Tensor, int]:
+     """Read audio by picking the most appropriate backend tool based on the audio format.
+
+     Args:
+         filepath (str or Path): Path to audio file to read.
+         seek_time (float): Time at which to start reading in the file.
+         duration (float): Duration to read from the file. If set to -1, the whole file is read.
+         pad (bool): Pad output audio if not reaching expected duration.
+     Returns:
+         tuple of torch.Tensor, int: Tuple containing audio data and sample rate.
+     """
+     fp = Path(filepath)
+     if fp.suffix in ['.flac', '.ogg']:  # TODO: check if we can safely use av_read for .ogg
+         # There is some bug with ffmpeg and reading flac
+         info = _soundfile_info(filepath)
+         frames = -1 if duration <= 0 else int(duration * info.sample_rate)
+         frame_offset = int(seek_time * info.sample_rate)
+         wav, sr = soundfile.read(filepath, start=frame_offset, frames=frames, dtype=np.float32)
+         assert info.sample_rate == sr, f"Mismatch of sample rates {info.sample_rate} {sr}"
+         wav = torch.from_numpy(wav).t().contiguous()
+         if len(wav.shape) == 1:
+             wav = torch.unsqueeze(wav, 0)
+     elif (
+         fp.suffix in ['.wav', '.mp3'] and fp.suffix[1:] in ta.utils.sox_utils.list_read_formats()
+         and duration <= 0 and seek_time == 0
+     ):
+         # Torchaudio is faster if we load an entire file at once.
+         wav, sr = ta.load(fp)
+     else:
+         wav, sr = _av_read(filepath, seek_time, duration)
+     if pad and duration > 0:
+         expected_frames = int(duration * sr)
+         wav = F.pad(wav, (0, expected_frames - wav.shape[-1]))
+     return wav, sr
+
+
+ def audio_write(stem_name: tp.Union[str, Path],
+                 wav: torch.Tensor, sample_rate: int,
+                 format: str = 'wav', mp3_rate: int = 320, normalize: bool = True,
+                 strategy: str = 'peak', peak_clip_headroom_db: float = 1,
+                 rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
+                 loudness_compressor: bool = False,
+                 log_clipping: bool = True, make_parent_dir: bool = True,
+                 add_suffix: bool = True) -> Path:
+     """Convenience function for saving audio to disk. Returns the filename the audio was written to.
+
+     Args:
+         stem_name (str or Path): Filename without extension which will be added automatically.
+         format (str): Either "wav" or "mp3".
+         mp3_rate (int): kbps when using mp3s.
+         normalize (bool): if `True` (default), normalizes according to the prescribed
+             strategy (see after). If `False`, the strategy is only used in case clipping
+             would happen.
+         strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
+             i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
+             with extra headroom to avoid clipping. 'clip' just clips.
+         peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
+         rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
+             than the `peak_clip` one to avoid further clipping.
+         loudness_headroom_db (float): Target loudness for loudness normalization.
+         loudness_compressor (bool): Uses tanh for soft clipping when strategy is 'loudness'.
+          when strategy is 'loudness' log_clipping (bool): If True, basic logging on stderr when clipping still
+          occurs despite strategy (only for 'rms').
+         make_parent_dir (bool): Make parent directory if it doesn't exist.
+     Returns:
+         Path: Path of the saved audio.
+     """
+     assert wav.dtype.is_floating_point, "wav is not floating point"
+     if wav.dim() == 1:
+         wav = wav[None]
+     elif wav.dim() > 2:
+         raise ValueError("Input wav should be at most 2 dimension.")
+     assert wav.isfinite().all()
+     wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db,
+                           rms_headroom_db, loudness_headroom_db, loudness_compressor,
+                           log_clipping=log_clipping, sample_rate=sample_rate,
+                           stem_name=str(stem_name))
+     kwargs: dict = {}
+     if format == 'mp3':
+         suffix = '.mp3'
+         kwargs.update({"compression": mp3_rate})
+     elif format == 'wav':
+         wav = i16_pcm(wav)
+         suffix = '.wav'
+         kwargs.update({"encoding": "PCM_S", "bits_per_sample": 16})
+     else:
+         raise RuntimeError(f"Invalid format {format}. Only wav or mp3 are supported.")
+     if not add_suffix:
+         suffix = ''
+     path = Path(str(stem_name) + suffix)
+     if make_parent_dir:
+         path.parent.mkdir(exist_ok=True, parents=True)
+     try:
+         ta.save(path, wav, sample_rate, **kwargs)
+     except Exception:
+         if path.exists():
+             # we do not want to leave half written files around.
+             path.unlink()
+         raise
+     return path
@@ -0,0 +1,587 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """AudioDataset support. In order to handle a larger number of files
7
+ without having to scan again the folders, we precompute some metadata
8
+ (filename, sample rate, duration), and use that to efficiently sample audio segments.
9
+ """
10
+ import argparse
11
+ import copy
12
+ from concurrent.futures import ThreadPoolExecutor, Future
13
+ from dataclasses import dataclass, fields
14
+ from contextlib import ExitStack
15
+ from functools import lru_cache
16
+ import gzip
17
+ import json
18
+ import logging
19
+ import os
20
+ from pathlib import Path
21
+ import random
22
+ import sys
23
+ import typing as tp
24
+
25
+ import torch
26
+ import torch.nn.functional as F
27
+
28
+ from .audio import audio_read, audio_info
29
+ from .audio_utils import convert_audio
30
+ from .zip import PathInZip
31
+
32
+ try:
33
+ import dora
34
+ except ImportError:
35
+ dora = None # type: ignore
36
+
37
+
38
+ @dataclass(order=True)
39
+ class BaseInfo:
40
+
41
+ @classmethod
42
+ def _dict2fields(cls, dictionary: dict):
43
+ return {
44
+ field.name: dictionary[field.name]
45
+ for field in fields(cls) if field.name in dictionary
46
+ }
47
+
48
+ @classmethod
49
+ def from_dict(cls, dictionary: dict):
50
+ _dictionary = cls._dict2fields(dictionary)
51
+ return cls(**_dictionary)
52
+
53
+ def to_dict(self):
54
+ return {
55
+ field.name: self.__getattribute__(field.name)
56
+ for field in fields(self)
57
+ }
58
+
59
+
60
+ @dataclass(order=True)
61
+ class AudioMeta(BaseInfo):
62
+ path: str
63
+ duration: float
64
+ sample_rate: int
65
+ amplitude: tp.Optional[float] = None
66
+ weight: tp.Optional[float] = None
67
+ # info_path is used to load additional information about the audio file that is stored in zip files.
68
+ info_path: tp.Optional[PathInZip] = None
69
+
70
+ @classmethod
71
+ def from_dict(cls, dictionary: dict):
72
+ base = cls._dict2fields(dictionary)
73
+ if 'info_path' in base and base['info_path'] is not None:
74
+ base['info_path'] = PathInZip(base['info_path'])
75
+ return cls(**base)
76
+
77
+ def to_dict(self):
78
+ d = super().to_dict()
79
+ if d['info_path'] is not None:
80
+ d['info_path'] = str(d['info_path'])
81
+ return d
82
+
83
+
84
+ @dataclass(order=True)
85
+ class SegmentInfo(BaseInfo):
86
+ meta: AudioMeta
87
+ seek_time: float
88
+ # The following values are given once the audio is processed, e.g.
89
+ # at the target sample rate and target number of channels.
90
+ n_frames: int # actual number of frames without padding
91
+ total_frames: int # total number of frames, padding included
92
+ sample_rate: int # actual sample rate
93
+ channels: int # number of audio channels.
94
+
95
+
96
+ DEFAULT_EXTS = ['.wav', '.mp3', '.flac', '.ogg', '.m4a']
97
+
98
+ logger = logging.getLogger(__name__)
99
+
100
+
101
+ def _get_audio_meta(file_path: str, minimal: bool = True) -> AudioMeta:
102
+ """AudioMeta from a path to an audio file.
103
+
104
+ Args:
105
+ file_path (str): Resolved path of valid audio file.
106
+ minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
107
+ Returns:
108
+ AudioMeta: Audio file path and its metadata.
109
+ """
110
+ info = audio_info(file_path)
111
+ amplitude: tp.Optional[float] = None
112
+ if not minimal:
113
+ wav, sr = audio_read(file_path)
114
+ amplitude = wav.abs().max().item()
115
+ return AudioMeta(file_path, info.duration, info.sample_rate, amplitude)
116
+
117
+
118
+ def _resolve_audio_meta(m: AudioMeta, fast: bool = True) -> AudioMeta:
119
+ """If Dora is available as a dependency, try to resolve potential relative paths
120
+ in list of AudioMeta. This method is expected to be used when loading meta from file.
121
+
122
+ Args:
123
+ m (AudioMeta): Audio meta to resolve.
124
+ fast (bool): If True, uses a really fast check for determining if a file
125
+ is already absolute or not. Only valid on Linux/Mac.
126
+ Returns:
127
+ AudioMeta: Audio meta with resolved path.
128
+ """
129
+ def is_abs(m):
130
+ if fast:
131
+ return str(m)[0] == '/'
132
+ else:
133
+ os.path.isabs(str(m))
134
+
135
+ if not dora:
136
+ return m
137
+
138
+ if not is_abs(m.path):
139
+ m.path = dora.git_save.to_absolute_path(m.path)
140
+ if m.info_path is not None and not is_abs(m.info_path.zip_path):
141
+ m.info_path.zip_path = dora.git_save.to_absolute_path(m.path)
142
+ return m
143
+
144
+
145
+ def find_audio_files(path: tp.Union[Path, str],
146
+ exts: tp.List[str] = DEFAULT_EXTS,
147
+ resolve: bool = True,
148
+ minimal: bool = True,
149
+ progress: bool = False,
150
+ workers: int = 0) -> tp.List[AudioMeta]:
151
+ """Build a list of AudioMeta from a given path,
152
+ collecting relevant audio files and fetching meta info.
153
+
154
+ Args:
155
+ path (str or Path): Path to folder containing audio files.
156
+ exts (list of str): List of file extensions to consider for audio files.
157
+ minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
158
+ progress (bool): Whether to log progress on audio files collection.
159
+ workers (int): number of parallel workers, if 0, use only the current thread.
160
+ Returns:
161
+ list of AudioMeta: List of audio file path and its metadata.
162
+ """
163
+ audio_files = []
164
+ futures: tp.List[Future] = []
165
+ pool: tp.Optional[ThreadPoolExecutor] = None
166
+ with ExitStack() as stack:
167
+ if workers > 0:
168
+ pool = ThreadPoolExecutor(workers)
169
+ stack.enter_context(pool)
170
+
171
+ if progress:
172
+ print("Finding audio files...")
173
+ for root, folders, files in os.walk(path, followlinks=True):
174
+ for file in files:
175
+ full_path = Path(root) / file
176
+ if full_path.suffix.lower() in exts:
177
+ audio_files.append(full_path)
178
+ if pool is not None:
179
+ futures.append(pool.submit(_get_audio_meta, str(audio_files[-1]), minimal))
180
+ if progress:
181
+ print(format(len(audio_files), " 8d"), end='\r', file=sys.stderr)
182
+
183
+ if progress:
184
+ print("Getting audio metadata...")
185
+ meta: tp.List[AudioMeta] = []
186
+ for idx, file_path in enumerate(audio_files):
187
+ try:
188
+ if pool is None:
189
+ m = _get_audio_meta(str(file_path), minimal)
190
+ else:
191
+ m = futures[idx].result()
192
+ if resolve:
193
+ m = _resolve_audio_meta(m)
194
+ except Exception as err:
195
+ print("Error with", str(file_path), err, file=sys.stderr)
196
+ continue
197
+ meta.append(m)
198
+ if progress:
199
+ print(format((1 + idx) / len(audio_files), " 3.1%"), end='\r', file=sys.stderr)
200
+ meta.sort()
201
+ return meta
202
+
203
+
204
+ def load_audio_meta(path: tp.Union[str, Path],
205
+ resolve: bool = True, fast: bool = True) -> tp.List[AudioMeta]:
206
+ """Load list of AudioMeta from an optionally compressed json file.
207
+
208
+ Args:
209
+ path (str or Path): Path to JSON file.
210
+ resolve (bool): Whether to resolve the path from AudioMeta (default=True).
211
+ fast (bool): activates some tricks to make things faster.
212
+ Returns:
213
+ list of AudioMeta: List of audio file path and its total duration.
214
+ """
215
+ open_fn = gzip.open if str(path).lower().endswith('.gz') else open
216
+ with open_fn(path, 'rb') as fp: # type: ignore
217
+ lines = fp.readlines()
218
+ meta = []
219
+ for line in lines:
220
+ d = json.loads(line)
221
+ m = AudioMeta.from_dict(d)
222
+ if resolve:
223
+ m = _resolve_audio_meta(m, fast=fast)
224
+ meta.append(m)
225
+ return meta
226
+
227
+
228
+ def save_audio_meta(path: tp.Union[str, Path], meta: tp.List[AudioMeta]):
229
+ """Save the audio metadata to the file pointer as json.
230
+
231
+ Args:
232
+ path (str or Path): Path to JSON file.
233
+ metadata (list of BaseAudioMeta): List of audio meta to save.
234
+ """
235
+ Path(path).parent.mkdir(exist_ok=True, parents=True)
236
+ open_fn = gzip.open if str(path).lower().endswith('.gz') else open
237
+ with open_fn(path, 'wb') as fp: # type: ignore
238
+ for m in meta:
239
+ json_str = json.dumps(m.to_dict()) + '\n'
240
+ json_bytes = json_str.encode('utf-8')
241
+ fp.write(json_bytes)
242
+
243
+
244
+ class AudioDataset:
245
+ """Base audio dataset.
246
+
247
+ The dataset takes a list of AudioMeta and create a dataset composed of segments of audio
248
+ and potentially additional information, by creating random segments from the list of audio
249
+ files referenced in the metadata and applying minimal data pre-processing such as resampling,
250
+ mixing of channels, padding, etc.
251
+
252
+ If no segment_duration value is provided, the AudioDataset will return the full wav for each
253
+ audio file. Otherwise, it will randomly sample audio files and create a segment of the specified
254
+ duration, applying padding if required.
255
+
256
+ By default, only the torch Tensor corresponding to the waveform is returned. Setting return_info=True
257
+ allows to return a tuple containing the torch Tensor and additional metadata on the segment and the
258
+ original audio meta.
259
+
260
+ Note that you can call `start_epoch(epoch)` in order to get
261
+ a deterministic "randomization" for `shuffle=True`.
262
+ For a given epoch and dataset index, this will always return the same extract.
263
+ You can get back some diversity by setting the `shuffle_seed` param.
264
+
265
+ Args:
266
+ meta (list of AudioMeta): List of audio files metadata.
267
+ segment_duration (float, optional): Optional segment duration of audio to load.
268
+ If not specified, the dataset will load the full audio segment from the file.
269
+ shuffle (bool): Set to `True` to have the data reshuffled at every epoch.
270
+ sample_rate (int): Target sample rate of the loaded audio samples.
271
+ channels (int): Target number of channels of the loaded audio samples.
272
+ sample_on_duration (bool): Set to `True` to sample segments with probability
273
+ dependent on audio file duration. This is only used if `segment_duration` is provided.
274
+ sample_on_weight (bool): Set to `True` to sample segments using the `weight` entry of
275
+ `AudioMeta`. If `sample_on_duration` is also True, the actual weight will be the product
276
+ of the file duration and file weight. This is only used if `segment_duration` is provided.
277
+ min_segment_ratio (float): Minimum segment ratio to use when the audio file
278
+ is shorter than the desired segment.
279
+ max_read_retry (int): Maximum number of retries to sample an audio segment from the dataset.
280
+ return_info (bool): Whether to return the wav only or return wav along with segment info and metadata.
281
+ min_audio_duration (float, optional): Minimum audio file duration, in seconds, if provided
282
+ audio shorter than this will be filtered out.
283
+ max_audio_duration (float, optional): Maximal audio file duration in seconds, if provided
284
+ audio longer than this will be filtered out.
285
+ shuffle_seed (int): can be used to further randomize
286
+ load_wav (bool): if False, skip loading the wav but returns a tensor of 0
287
+ with the expected segment_duration (which must be provided if load_wav is False).
288
+ permutation_on_files (bool): only if `sample_on_weight` and `sample_on_duration`
289
+ are False. Will ensure a permutation on files when going through the dataset.
290
+ In that case the epoch number must be provided in order for the model
291
+ to continue the permutation across epochs. In that case, it is assumed
292
+ that `num_samples = total_batch_size * num_updates_per_epoch`, with
293
+ `total_batch_size` the overall batch size accounting for all gpus.
294
+ """
295
+ def __init__(self,
296
+ meta: tp.List[AudioMeta],
297
+ segment_duration: tp.Optional[float] = None,
298
+ shuffle: bool = True,
299
+ num_samples: int = 10_000,
300
+ sample_rate: int = 48_000,
301
+ channels: int = 2,
302
+ pad: bool = True,
303
+ sample_on_duration: bool = True,
304
+ sample_on_weight: bool = True,
305
+ min_segment_ratio: float = 0.5,
306
+ max_read_retry: int = 10,
307
+ return_info: bool = False,
308
+ min_audio_duration: tp.Optional[float] = None,
309
+ max_audio_duration: tp.Optional[float] = None,
310
+ shuffle_seed: int = 0,
311
+ load_wav: bool = True,
312
+ permutation_on_files: bool = False,
313
+ ):
314
+ assert len(meta) > 0, "No audio meta provided to AudioDataset. Please check loading of audio meta."
315
+ assert segment_duration is None or segment_duration > 0
316
+ assert segment_duration is None or min_segment_ratio >= 0
317
+ self.segment_duration = segment_duration
318
+ self.min_segment_ratio = min_segment_ratio
319
+ self.max_audio_duration = max_audio_duration
320
+ self.min_audio_duration = min_audio_duration
321
+ if self.min_audio_duration is not None and self.max_audio_duration is not None:
322
+ assert self.min_audio_duration <= self.max_audio_duration
323
+ self.meta: tp.List[AudioMeta] = self._filter_duration(meta)
324
+ assert len(self.meta) # Fail fast if all data has been filtered.
325
+ self.total_duration = sum(d.duration for d in self.meta)
326
+
327
+ if segment_duration is None:
328
+ num_samples = len(self.meta)
329
+ self.num_samples = num_samples
330
+ self.shuffle = shuffle
331
+ self.sample_rate = sample_rate
332
+ self.channels = channels
333
+ self.pad = pad
334
+ self.sample_on_weight = sample_on_weight
335
+ self.sample_on_duration = sample_on_duration
336
+ self.sampling_probabilities = self._get_sampling_probabilities()
337
+ self.max_read_retry = max_read_retry
338
+ self.return_info = return_info
339
+ self.shuffle_seed = shuffle_seed
340
+ self.current_epoch: tp.Optional[int] = None
341
+ self.load_wav = load_wav
342
+ if not load_wav:
343
+ assert segment_duration is not None
344
+ self.permutation_on_files = permutation_on_files
345
+ if permutation_on_files:
346
+ assert not self.sample_on_duration
347
+ assert not self.sample_on_weight
348
+ assert self.shuffle
349
+
350
+ def start_epoch(self, epoch: int):
351
+ self.current_epoch = epoch
352
+
353
+ def __len__(self):
354
+ return self.num_samples
355
+
356
+ def _get_sampling_probabilities(self, normalized: bool = True):
357
+ """Return the sampling probabilities for each file inside `self.meta`."""
358
+ scores: tp.List[float] = []
359
+ for file_meta in self.meta:
360
+ score = 1.
361
+ if self.sample_on_weight and file_meta.weight is not None:
362
+ score *= file_meta.weight
363
+ if self.sample_on_duration:
364
+ score *= file_meta.duration
365
+ scores.append(score)
366
+ probabilities = torch.tensor(scores)
367
+ if normalized:
368
+ probabilities /= probabilities.sum()
369
+ return probabilities
370
+
371
+ @staticmethod
372
+ @lru_cache(16)
373
+ def _get_file_permutation(num_files: int, permutation_index: int, base_seed: int):
374
+ # Used to keep the most recent file permutations in memory implicitly.
375
+ # This will work unless someone is using a lot of Datasets in parallel.
376
+ rng = torch.Generator()
377
+ rng.manual_seed(base_seed + permutation_index)
378
+ return torch.randperm(num_files, generator=rng)
379
+
380
+ def sample_file(self, index: int, rng: torch.Generator) -> AudioMeta:
381
+ """Sample a given file from `self.meta`. Can be overridden in subclasses.
382
+ This is only called if `segment_duration` is not None.
383
+
384
+ You must use the provided random number generator `rng` for reproducibility.
385
+ You can further make use of the index accessed.
386
+ """
387
+ if self.permutation_on_files:
388
+ assert self.current_epoch is not None
389
+ total_index = self.current_epoch * len(self) + index
390
+ permutation_index = total_index // len(self.meta)
391
+ relative_index = total_index % len(self.meta)
392
+ permutation = AudioDataset._get_file_permutation(
393
+ len(self.meta), permutation_index, self.shuffle_seed)
394
+ file_index = permutation[relative_index]
395
+ return self.meta[file_index]
396
+
397
+ if not self.sample_on_weight and not self.sample_on_duration:
398
+ file_index = int(torch.randint(len(self.sampling_probabilities), (1,), generator=rng).item())
399
+ else:
400
+ file_index = int(torch.multinomial(self.sampling_probabilities, 1, generator=rng).item())
401
+
402
+ return self.meta[file_index]
403
+
404
+ def _audio_read(self, path: str, seek_time: float = 0, duration: float = -1):
405
+ # Override this method in subclass if needed.
406
+ if self.load_wav:
407
+ return audio_read(path, seek_time, duration, pad=False)
408
+ else:
409
+ assert self.segment_duration is not None
410
+ n_frames = int(self.sample_rate * self.segment_duration)
411
+ return torch.zeros(self.channels, n_frames), self.sample_rate
412
+
413
+ def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentInfo]]:
414
+ if self.segment_duration is None:
415
+ file_meta = self.meta[index]
416
+ out, sr = audio_read(file_meta.path)
417
+ out = convert_audio(out, sr, self.sample_rate, self.channels)
418
+ n_frames = out.shape[-1]
419
+ segment_info = SegmentInfo(file_meta, seek_time=0., n_frames=n_frames, total_frames=n_frames,
420
+ sample_rate=self.sample_rate, channels=out.shape[0])
421
+ else:
422
+ rng = torch.Generator()
423
+ if self.shuffle:
424
+ # We use index, plus extra randomness, either totally random if we don't know the epoch.
425
+ # otherwise we make use of the epoch number and optional shuffle_seed.
426
+ if self.current_epoch is None:
427
+ rng.manual_seed(index + self.num_samples * random.randint(0, 2**24))
428
+ else:
429
+ rng.manual_seed(index + self.num_samples * (self.current_epoch + self.shuffle_seed))
430
+ else:
431
+ # We only use index
432
+ rng.manual_seed(index)
433
+
434
+ for retry in range(self.max_read_retry):
435
+ file_meta = self.sample_file(index, rng)
436
+ # We add some variance in the file position even if the audio file is smaller than the segment,
437
+ # without ending up with empty segments.
438
+ max_seek = max(0, file_meta.duration - self.segment_duration * self.min_segment_ratio)
439
+ seek_time = torch.rand(1, generator=rng).item() * max_seek
440
+ try:
441
+ out, sr = audio_read(file_meta.path, seek_time, self.segment_duration, pad=False)
442
+ out = convert_audio(out, sr, self.sample_rate, self.channels)
443
+ n_frames = out.shape[-1]
444
+ target_frames = int(self.segment_duration * self.sample_rate)
445
+ if self.pad:
446
+ out = F.pad(out, (0, target_frames - n_frames))
447
+ segment_info = SegmentInfo(file_meta, seek_time, n_frames=n_frames, total_frames=target_frames,
448
+ sample_rate=self.sample_rate, channels=out.shape[0])
449
+ except Exception as exc:
450
+ logger.warning("Error opening file %s: %r", file_meta.path, exc)
451
+ if retry == self.max_read_retry - 1:
452
+ raise
453
+ else:
454
+ break
455
+
456
+ if self.return_info:
457
+ # Returns the wav and additional information on the wave segment
458
+ return out, segment_info
459
+ else:
460
+ return out
461
+
462
+ def collater(self, samples):
463
+ """The collater function has to be provided to the dataloader
464
+ if AudioDataset has return_info=True in order to properly collate
465
+ the samples of a batch.
466
+ """
467
+ if self.segment_duration is None and len(samples) > 1:
468
+ assert self.pad, "Must allow padding when batching examples of different durations."
469
+
470
+ # In this case the audio reaching the collater is of variable length as segment_duration=None.
471
+ to_pad = self.segment_duration is None and self.pad
472
+ if to_pad:
473
+ max_len = max([wav.shape[-1] for wav, _ in samples])
474
+
475
+ def _pad_wav(wav):
476
+ return F.pad(wav, (0, max_len - wav.shape[-1]))
477
+
478
+ if self.return_info:
479
+ if len(samples) > 0:
480
+ assert len(samples[0]) == 2
481
+ assert isinstance(samples[0][0], torch.Tensor)
482
+ assert isinstance(samples[0][1], SegmentInfo)
483
+
484
+ wavs = [wav for wav, _ in samples]
485
+ segment_infos = [copy.deepcopy(info) for _, info in samples]
486
+
487
+ if to_pad:
488
+ # Each wav could be of a different duration as they are not segmented.
489
+ for i in range(len(samples)):
490
+ # Determines the total length of the signal with padding, so we update here as we pad.
491
+ segment_infos[i].total_frames = max_len
492
+ wavs[i] = _pad_wav(wavs[i])
493
+
494
+ wav = torch.stack(wavs)
495
+ return wav, segment_infos
496
+ else:
497
+ assert isinstance(samples[0], torch.Tensor)
498
+ if to_pad:
499
+ samples = [_pad_wav(s) for s in samples]
500
+ return torch.stack(samples)
501
+
502
+ def _filter_duration(self, meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
503
+ """Filters out audio files with audio durations that will not allow to sample examples from them."""
504
+ orig_len = len(meta)
505
+
506
+ # Filter data that is too short.
507
+ if self.min_audio_duration is not None:
508
+ meta = [m for m in meta if m.duration >= self.min_audio_duration]
509
+
510
+ # Filter data that is too long.
511
+ if self.max_audio_duration is not None:
512
+ meta = [m for m in meta if m.duration <= self.max_audio_duration]
513
+
514
+ filtered_len = len(meta)
515
+ removed_percentage = 100*(1-float(filtered_len)/orig_len)
516
+ msg = 'Removed %.2f percent of the data because it was too short or too long.' % removed_percentage
517
+ if removed_percentage < 10:
518
+ logging.debug(msg)
519
+ else:
520
+ logging.warning(msg)
521
+ return meta
522
+
523
+ @classmethod
524
+ def from_meta(cls, root: tp.Union[str, Path], **kwargs):
525
+ """Instantiate AudioDataset from a path to a directory containing a manifest as a jsonl file.
526
+
527
+ Args:
528
+ root (str or Path): Path to root folder containing audio files.
529
+ kwargs: Additional keyword arguments for the AudioDataset.
530
+ """
531
+ root = Path(root)
532
+ if root.is_dir():
533
+ if (root / 'data.jsonl').exists():
534
+ root = root / 'data.jsonl'
535
+ elif (root / 'data.jsonl.gz').exists():
536
+ root = root / 'data.jsonl.gz'
537
+ else:
538
+ raise ValueError("Don't know where to read metadata from in the dir. "
539
+ "Expecting either a data.jsonl or data.jsonl.gz file but none found.")
540
+ meta = load_audio_meta(root)
541
+ return cls(meta, **kwargs)
542
+
543
+ @classmethod
544
+ def from_path(cls, root: tp.Union[str, Path], minimal_meta: bool = True,
545
+ exts: tp.List[str] = DEFAULT_EXTS, **kwargs):
546
+ """Instantiate AudioDataset from a path containing (possibly nested) audio files.
547
+
548
+ Args:
549
+ root (str or Path): Path to root folder containing audio files.
550
+ minimal_meta (bool): Whether to only load minimal metadata or not.
551
+ exts (list of str): Extensions for audio files.
552
+ kwargs: Additional keyword arguments for the AudioDataset.
553
+ """
554
+ root = Path(root)
555
+ if root.is_file():
556
+ meta = load_audio_meta(root, resolve=True)
557
+ else:
558
+ meta = find_audio_files(root, exts, minimal=minimal_meta, resolve=True)
559
+ return cls(meta, **kwargs)
560
+
561
+
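To make the pieces above concrete, here is a minimal usage sketch (not part of the module): the folder path, batch size and durations are invented, and it assumes the package is installed so that `audiocraft.data.audio_dataset` is importable.

```python
# Illustrative sketch only: paths and hyperparameters below are invented.
from torch.utils.data import DataLoader

from audiocraft.data.audio_dataset import AudioDataset

dataset = AudioDataset.from_path(
    "/path/to/audio_folder",   # hypothetical folder of audio files
    segment_duration=10.0,     # sample 10-second segments
    sample_rate=32_000,
    channels=1,
    return_info=True,          # also return a SegmentInfo per item
)

# With return_info=True, the dataset's own collater must be given to the DataLoader.
loader = DataLoader(dataset, batch_size=4, collate_fn=dataset.collater)
wavs, infos = next(iter(loader))
print(wavs.shape)                              # e.g. torch.Size([4, 1, 320000])
print(infos[0].meta.path, infos[0].seek_time)  # where each segment came from
```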
562
+ def main():
563
+ logging.basicConfig(stream=sys.stderr, level=logging.INFO)
564
+ parser = argparse.ArgumentParser(
565
+ prog='audio_dataset',
566
+ description='Generate .jsonl files by scanning a folder.')
567
+ parser.add_argument('root', help='Root folder with all the audio files')
568
+ parser.add_argument('output_meta_file',
569
+ help='Output file to store the metadata, ')
570
+ parser.add_argument('--complete',
571
+ action='store_false', dest='minimal', default=True,
572
+ help='Retrieve all metadata, even the one that are expansive '
573
+ 'to compute (e.g. normalization).')
574
+ parser.add_argument('--resolve',
575
+ action='store_true', default=False,
576
+ help='Resolve the paths to be absolute and with no symlinks.')
577
+ parser.add_argument('--workers',
578
+ default=10, type=int,
579
+ help='Number of workers.')
580
+ args = parser.parse_args()
581
+ meta = find_audio_files(args.root, DEFAULT_EXTS, progress=True,
582
+ resolve=args.resolve, minimal=args.minimal, workers=args.workers)
583
+ save_audio_meta(args.output_meta_file, meta)
584
+
585
+
586
+ if __name__ == '__main__':
587
+ main()
audiocraft/data/audio_utils.py ADDED
@@ -0,0 +1,176 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Various utilities for audio conversion (PCM format, sample rate and channels),
7
+ and volume normalization."""
8
+ import sys
9
+ import typing as tp
10
+
11
+ import julius
12
+ import torch
13
+ import torchaudio
14
+
15
+
16
+ def convert_audio_channels(wav: torch.Tensor, channels: int = 2) -> torch.Tensor:
17
+ """Convert audio to the given number of channels.
18
+
19
+ Args:
20
+ wav (torch.Tensor): Audio wave of shape [B, C, T].
21
+ channels (int): Expected number of channels as output.
22
+ Returns:
23
+ torch.Tensor: Downmixed or unchanged audio wave [B, C, T].
24
+ """
25
+ *shape, src_channels, length = wav.shape
26
+ if src_channels == channels:
27
+ pass
28
+ elif channels == 1:
29
+ # Case 1:
30
+ # The caller asked for 1-channel audio, and the stream has multiple
31
+ # channels; downmix all channels.
32
+ wav = wav.mean(dim=-2, keepdim=True)
33
+ elif src_channels == 1:
34
+ # Case 2:
35
+ # The caller asked for multiple channels, but the input file has
36
+ # a single channel, replicate the audio over all channels.
37
+ wav = wav.expand(*shape, channels, length)
38
+ elif src_channels >= channels:
39
+ # Case 3:
40
+ # The caller asked for multiple channels, and the input file has
41
+ # more channels than requested. In that case return the first channels.
42
+ wav = wav[..., :channels, :]
43
+ else:
44
+ # Case 4: What is a reasonable choice here?
45
+ raise ValueError('The audio file has less channels than requested but is not mono.')
46
+ return wav
47
+
48
+
49
+ def convert_audio(wav: torch.Tensor, from_rate: float,
50
+ to_rate: float, to_channels: int) -> torch.Tensor:
51
+ """Convert audio to new sample rate and number of audio channels."""
52
+ wav = julius.resample_frac(wav, int(from_rate), int(to_rate))
53
+ wav = convert_audio_channels(wav, to_channels)
54
+ return wav
55
+
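For illustration only, converting a stereo 44.1 kHz clip to mono 16 kHz with the helper above could look like this (random data stands in for real audio):

```python
import torch

from audiocraft.data.audio_utils import convert_audio

wav = torch.randn(2, 44_100).clamp(-1, 1)  # fake 1-second stereo clip at 44.1 kHz
out = convert_audio(wav, from_rate=44_100, to_rate=16_000, to_channels=1)
print(out.shape)  # torch.Size([1, 16000]): resampled and downmixed to mono
```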
56
+
57
+ def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db: float = 14,
58
+ loudness_compressor: bool = False, energy_floor: float = 2e-3):
59
+ """Normalize an input signal to a user loudness in dB LKFS.
60
+ Audio loudness is defined according to the ITU-R BS.1770-4 recommendation.
61
+
62
+ Args:
63
+ wav (torch.Tensor): Input multichannel audio data.
64
+ sample_rate (int): Sample rate.
65
+ loudness_headroom_db (float): Target loudness of the output in dB LUFS.
66
+ loudness_compressor (bool): Uses tanh for soft clipping.
67
+ energy_floor (float): anything below that RMS level will not be rescaled.
68
+ Returns:
69
+ torch.Tensor: Loudness normalized output data.
70
+ """
71
+ energy = wav.pow(2).mean().sqrt().item()
72
+ if energy < energy_floor:
73
+ return wav
74
+ transform = torchaudio.transforms.Loudness(sample_rate)
75
+ input_loudness_db = transform(wav).item()
76
+ # calculate the gain needed to scale to the desired loudness level
77
+ delta_loudness = -loudness_headroom_db - input_loudness_db
78
+ gain = 10.0 ** (delta_loudness / 20.0)
79
+ output = gain * wav
80
+ if loudness_compressor:
81
+ output = torch.tanh(output)
82
+ assert output.isfinite().all(), (input_loudness_db, wav.pow(2).mean().sqrt())
83
+ return output
84
+
85
+
86
+ def _clip_wav(wav: torch.Tensor, log_clipping: bool = False, stem_name: tp.Optional[str] = None) -> None:
87
+ """Utility function to clip the audio with logging if specified."""
88
+ max_scale = wav.abs().max()
89
+ if log_clipping and max_scale > 1:
90
+ clamp_prob = (wav.abs() > 1).float().mean().item()
91
+ print(f"CLIPPING {stem_name or ''} happening with proba (a bit of clipping is okay):",
92
+ clamp_prob, "maximum scale: ", max_scale.item(), file=sys.stderr)
93
+ wav.clamp_(-1, 1)
94
+
95
+
96
+ def normalize_audio(wav: torch.Tensor, normalize: bool = True,
97
+ strategy: str = 'peak', peak_clip_headroom_db: float = 1,
98
+ rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
99
+ loudness_compressor: bool = False, log_clipping: bool = False,
100
+ sample_rate: tp.Optional[int] = None,
101
+ stem_name: tp.Optional[str] = None) -> torch.Tensor:
102
+ """Normalize the audio according to the prescribed strategy (see after).
103
+
104
+ Args:
105
+ wav (torch.Tensor): Audio data.
106
+ normalize (bool): if `True` (default), normalizes according to the prescribed
107
+ strategy (see after). If `False`, the strategy is only used in case clipping
108
+ would happen.
109
+ strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
110
+ i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
111
+ with extra headroom to avoid clipping. 'clip' just clips.
112
+ peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
113
+ rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
114
+ than the `peak_clip` one to avoid further clipping.
115
+ loudness_headroom_db (float): Target loudness for loudness normalization.
116
+ loudness_compressor (bool): If True, uses tanh based soft clipping.
117
+ log_clipping (bool): If True, basic logging on stderr when clipping still
118
+ occurs despite strategy (only for 'rms').
119
+ sample_rate (int): Sample rate for the audio data (required for loudness).
120
+ stem_name (str, optional): Stem name for clipping logging.
121
+ Returns:
122
+ torch.Tensor: Normalized audio.
123
+ """
124
+ scale_peak = 10 ** (-peak_clip_headroom_db / 20)
125
+ scale_rms = 10 ** (-rms_headroom_db / 20)
126
+ if strategy == 'peak':
127
+ rescaling = (scale_peak / wav.abs().max())
128
+ if normalize or rescaling < 1:
129
+ wav = wav * rescaling
130
+ elif strategy == 'clip':
131
+ wav = wav.clamp(-scale_peak, scale_peak)
132
+ elif strategy == 'rms':
133
+ mono = wav.mean(dim=0)
134
+ rescaling = scale_rms / mono.pow(2).mean().sqrt()
135
+ if normalize or rescaling < 1:
136
+ wav = wav * rescaling
137
+ _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
138
+ elif strategy == 'loudness':
139
+ assert sample_rate is not None, "Loudness normalization requires sample rate."
140
+ wav = normalize_loudness(wav, sample_rate, loudness_headroom_db, loudness_compressor)
141
+ _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
142
+ else:
143
+ assert wav.abs().max() < 1
144
+ assert strategy == '' or strategy == 'none', f"Unexpected strategy: '{strategy}'"
145
+ return wav
146
+
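A short, hedged example of the strategies described above (values are arbitrary; the 'loudness' strategy additionally needs `sample_rate`):

```python
import torch

from audiocraft.data.audio_utils import normalize_audio

wav = 0.5 * torch.randn(1, 32_000)  # fake mono audio, roughly 1 second at 32 kHz

# Peak normalization with 1 dB of headroom (the default strategy).
peak = normalize_audio(wav, strategy='peak', peak_clip_headroom_db=1)

# RMS normalization; residual clipping can optionally be logged to stderr.
rms = normalize_audio(wav, strategy='rms', rms_headroom_db=18, log_clipping=True)

# Loudness normalization requires the sample rate (uses torchaudio's Loudness).
loud = normalize_audio(wav, strategy='loudness', sample_rate=32_000)
```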
147
+
148
+ def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
149
+ """Convert audio to float 32 bits PCM format.
150
+ """
151
+ if wav.dtype.is_floating_point:
152
+ return wav
153
+ elif wav.dtype == torch.int16:
154
+ return wav.float() / 2**15
155
+ elif wav.dtype == torch.int32:
156
+ return wav.float() / 2**31
157
+ raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
158
+
159
+
160
+ def i16_pcm(wav: torch.Tensor) -> torch.Tensor:
161
+ """Convert audio to int 16 bits PCM format.
162
+
163
+ .. warning:: There exist many formulas for doing this conversion. None are perfect
164
+ due to the asymmetry of the int16 range. One either has possible clipping, a DC offset,
165
+ or inconsistencies with f32_pcm. If the given wav doesn't have enough headroom,
166
+ it is possible that `i16_pcm(f32_pcm(wav)) != wav`.
167
+ """
168
+ if wav.dtype.is_floating_point:
169
+ assert wav.abs().max() <= 1
170
+ candidate = (wav * 2 ** 15).round()
171
+ if candidate.max() >= 2 ** 15: # clipping would occur
172
+ candidate = (wav * (2 ** 15 - 1)).round()
173
+ return candidate.short()
174
+ else:
175
+ assert wav.dtype == torch.int16
176
+ return wav
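To make the warning above concrete, a small round-trip sketch (illustrative only):

```python
import torch

from audiocraft.data.audio_utils import f32_pcm, i16_pcm

quiet = (0.5 * torch.randn(1, 1000)).clamp(-0.5, 0.5)  # plenty of headroom
back = f32_pcm(i16_pcm(quiet))
print((back - quiet).abs().max())  # tiny quantization error, on the order of 1e-5

loud = torch.ones(1, 1000)         # full-scale signal hits the clipping branch
print(i16_pcm(loud).max())         # tensor(32767, dtype=torch.int16), i.e. 2**15 - 1
```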
audiocraft/data/info_audio_dataset.py ADDED
@@ -0,0 +1,110 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Base classes for the datasets that also provide non-audio metadata,
7
+ e.g. description, text transcription etc.
8
+ """
9
+ from dataclasses import dataclass
10
+ import logging
11
+ import math
12
+ import re
13
+ import typing as tp
14
+
15
+ import torch
16
+
17
+ from .audio_dataset import AudioDataset, AudioMeta
18
+ from ..environment import AudioCraftEnvironment
19
+ from ..modules.conditioners import SegmentWithAttributes, ConditioningAttributes
20
+
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ def _clusterify_meta(meta: AudioMeta) -> AudioMeta:
26
+ """Monkey-patch meta to match cluster specificities."""
27
+ meta.path = AudioCraftEnvironment.apply_dataset_mappers(meta.path)
28
+ if meta.info_path is not None:
29
+ meta.info_path.zip_path = AudioCraftEnvironment.apply_dataset_mappers(meta.info_path.zip_path)
30
+ return meta
31
+
32
+
33
+ def clusterify_all_meta(meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
34
+ """Monkey-patch all meta to match cluster specificities."""
35
+ return [_clusterify_meta(m) for m in meta]
36
+
37
+
38
+ @dataclass
39
+ class AudioInfo(SegmentWithAttributes):
40
+ """Dummy SegmentInfo with empty attributes.
41
+
42
+ The InfoAudioDataset is expected to return metadata that inherits
43
+ from SegmentWithAttributes class and can return conditioning attributes.
44
+
45
+ This basically guarantees all datasets will be compatible with the current
46
+ solvers that contain conditioners requiring this.
47
+ """
48
+ audio_tokens: tp.Optional[torch.Tensor] = None # populated when using cached batch for training a LM.
49
+
50
+ def to_condition_attributes(self) -> ConditioningAttributes:
51
+ return ConditioningAttributes()
52
+
53
+
54
+ class InfoAudioDataset(AudioDataset):
55
+ """AudioDataset that always returns metadata as SegmentWithAttributes along with the audio waveform.
56
+
57
+ See `audiocraft.data.audio_dataset.AudioDataset` for initialization arguments.
58
+ """
59
+ def __init__(self, meta: tp.List[AudioMeta], **kwargs):
60
+ super().__init__(clusterify_all_meta(meta), **kwargs)
61
+
62
+ def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentWithAttributes]]:
63
+ if not self.return_info:
64
+ wav = super().__getitem__(index)
65
+ assert isinstance(wav, torch.Tensor)
66
+ return wav
67
+ wav, meta = super().__getitem__(index)
68
+ return wav, AudioInfo(**meta.to_dict())
69
+
70
+
71
+ def get_keyword_or_keyword_list(value: tp.Optional[str]) -> tp.Union[tp.Optional[str], tp.Optional[tp.List[str]]]:
72
+ """Preprocess a single keyword or possibly a list of keywords."""
73
+ if isinstance(value, list):
74
+ return get_keyword_list(value)
75
+ else:
76
+ return get_keyword(value)
77
+
78
+
79
+ def get_string(value: tp.Optional[str]) -> tp.Optional[str]:
80
+ """Preprocess a single keyword."""
81
+ if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
82
+ return None
83
+ else:
84
+ return value.strip()
85
+
86
+
87
+ def get_keyword(value: tp.Optional[str]) -> tp.Optional[str]:
88
+ """Preprocess a single keyword."""
89
+ if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
90
+ return None
91
+ else:
92
+ return value.strip().lower()
93
+
94
+
95
+ def get_keyword_list(values: tp.Union[str, tp.List[str]]) -> tp.Optional[tp.List[str]]:
96
+ """Preprocess a list of keywords."""
97
+ if isinstance(values, str):
98
+ values = [v.strip() for v in re.split(r'[,\s]', values)]
99
+ elif isinstance(values, float) and math.isnan(values):
100
+ values = []
101
+ if not isinstance(values, list):
102
+ logger.debug(f"Unexpected keyword list {values}")
103
+ values = [str(values)]
104
+
105
+ kws = [get_keyword(v) for v in values]
106
+ kw_list = [k for k in kws if k is not None]
107
+ if len(kw_list) == 0:
108
+ return None
109
+ else:
110
+ return kw_list
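For illustration, the helpers above behave roughly as follows (inputs chosen arbitrarily):

```python
from audiocraft.data.info_audio_dataset import get_keyword, get_keyword_list, get_string

print(get_string("  Hello World  "))        # 'Hello World'
print(get_keyword("  Rock "))               # 'rock'
print(get_keyword("None"))                  # None
print(get_keyword_list("rock, pop jazz"))   # ['rock', 'pop', 'jazz']
print(get_keyword_list(["Lo-Fi", "None"]))  # ['lo-fi']
```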
audiocraft/data/music_dataset.py ADDED
@@ -0,0 +1,270 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Dataset of music tracks with rich metadata.
7
+ """
8
+ from dataclasses import dataclass, field, fields, replace
9
+ import gzip
10
+ import json
11
+ import logging
12
+ from pathlib import Path
13
+ import random
14
+ import typing as tp
15
+
16
+ import torch
17
+
18
+ from .info_audio_dataset import (
19
+ InfoAudioDataset,
20
+ AudioInfo,
21
+ get_keyword_list,
22
+ get_keyword,
23
+ get_string
24
+ )
25
+ from ..modules.conditioners import (
26
+ ConditioningAttributes,
27
+ JointEmbedCondition,
28
+ WavCondition,
29
+ )
30
+ from ..utils.utils import warn_once
31
+
32
+
33
+ logger = logging.getLogger(__name__)
34
+
35
+
36
+ @dataclass
37
+ class MusicInfo(AudioInfo):
38
+ """Segment info augmented with music metadata.
39
+ """
40
+ # music-specific metadata
41
+ title: tp.Optional[str] = None
42
+ artist: tp.Optional[str] = None # anonymized artist id, used to ensure no overlap between splits
43
+ key: tp.Optional[str] = None
44
+ bpm: tp.Optional[float] = None
45
+ genre: tp.Optional[str] = None
46
+ moods: tp.Optional[list] = None
47
+ keywords: tp.Optional[list] = None
48
+ description: tp.Optional[str] = None
49
+ name: tp.Optional[str] = None
50
+ instrument: tp.Optional[str] = None
51
+ # original wav accompanying the metadata
52
+ self_wav: tp.Optional[WavCondition] = None
53
+ # dict mapping attributes names to tuple of wav, text and metadata
54
+ joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)
55
+
56
+ @property
57
+ def has_music_meta(self) -> bool:
58
+ return self.name is not None
59
+
60
+ def to_condition_attributes(self) -> ConditioningAttributes:
61
+ out = ConditioningAttributes()
62
+ for _field in fields(self):
63
+ key, value = _field.name, getattr(self, _field.name)
64
+ if key == 'self_wav':
65
+ out.wav[key] = value
66
+ elif key == 'joint_embed':
67
+ for embed_attribute, embed_cond in value.items():
68
+ out.joint_embed[embed_attribute] = embed_cond
69
+ else:
70
+ if isinstance(value, list):
71
+ value = ' '.join(value)
72
+ out.text[key] = value
73
+ return out
74
+
75
+ @staticmethod
76
+ def attribute_getter(attribute):
77
+ if attribute == 'bpm':
78
+ preprocess_func = get_bpm
79
+ elif attribute == 'key':
80
+ preprocess_func = get_musical_key
81
+ elif attribute in ['moods', 'keywords']:
82
+ preprocess_func = get_keyword_list
83
+ elif attribute in ['genre', 'name', 'instrument']:
84
+ preprocess_func = get_keyword
85
+ elif attribute in ['title', 'artist', 'description']:
86
+ preprocess_func = get_string
87
+ else:
88
+ preprocess_func = None
89
+ return preprocess_func
90
+
91
+ @classmethod
92
+ def from_dict(cls, dictionary: dict, fields_required: bool = False):
93
+ _dictionary: tp.Dict[str, tp.Any] = {}
94
+
95
+ # allow a subset of attributes to not be loaded from the dictionary
96
+ # these attributes may be populated later
97
+ post_init_attributes = ['self_wav', 'joint_embed']
98
+ optional_fields = ['keywords']
99
+
100
+ for _field in fields(cls):
101
+ if _field.name in post_init_attributes:
102
+ continue
103
+ elif _field.name not in dictionary:
104
+ if fields_required and _field.name not in optional_fields:
105
+ raise KeyError(f"Unexpected missing key: {_field.name}")
106
+ else:
107
+ preprocess_func: tp.Optional[tp.Callable] = cls.attribute_getter(_field.name)
108
+ value = dictionary[_field.name]
109
+ if preprocess_func:
110
+ value = preprocess_func(value)
111
+ _dictionary[_field.name] = value
112
+ return cls(**_dictionary)
113
+
114
+
115
+ def augment_music_info_description(music_info: MusicInfo, merge_text_p: float = 0.,
116
+ drop_desc_p: float = 0., drop_other_p: float = 0.) -> MusicInfo:
117
+ """Augment MusicInfo description with additional metadata fields and potential dropout.
118
+ Additional textual attributes are added given probability 'merge_text_p' and
119
+ the original textual description is dropped from the augmented description given probability drop_desc_p.
120
+
121
+ Args:
122
+ music_info (MusicInfo): The music metadata to augment.
123
+ merge_text_p (float): Probability of merging additional metadata to the description.
124
+ If provided value is 0, then no merging is performed.
125
+ drop_desc_p (float): Probability of dropping the original description on text merge.
126
+ If provided value is 0, then no dropout is performed.
127
+ drop_other_p (float): Probability of dropping the other fields used for text augmentation.
128
+ Returns:
129
+ MusicInfo: The MusicInfo with augmented textual description.
130
+ """
131
+ def is_valid_field(field_name: str, field_value: tp.Any) -> bool:
132
+ valid_field_name = field_name in ['key', 'bpm', 'genre', 'moods', 'instrument', 'keywords']
133
+ valid_field_value = field_value is not None and isinstance(field_value, (int, float, str, list))
134
+ keep_field = random.uniform(0, 1) < drop_other_p
135
+ return valid_field_name and valid_field_value and keep_field
136
+
137
+ def process_value(v: tp.Any) -> str:
138
+ if isinstance(v, (int, float, str)):
139
+ return str(v)
140
+ if isinstance(v, list):
141
+ return ", ".join(v)
142
+ else:
143
+ raise ValueError(f"Unknown type for text value! ({type(v), v})")
144
+
145
+ description = music_info.description
146
+
147
+ metadata_text = ""
148
+ if random.uniform(0, 1) < merge_text_p:
149
+ meta_pairs = [f'{_field.name}: {process_value(getattr(music_info, _field.name))}'
150
+ for _field in fields(music_info) if is_valid_field(_field.name, getattr(music_info, _field.name))]
151
+ random.shuffle(meta_pairs)
152
+ metadata_text = ". ".join(meta_pairs)
153
+ description = description if not random.uniform(0, 1) < drop_desc_p else None
154
+ logger.debug(f"Applying text augmentation on MMI info. description: {description}, metadata: {metadata_text}")
155
+
156
+ if description is None:
157
+ description = metadata_text if len(metadata_text) > 1 else None
158
+ else:
159
+ description = ". ".join([description.rstrip('.'), metadata_text])
160
+ description = description.strip() if description else None
161
+
162
+ music_info = replace(music_info)
163
+ music_info.description = description
164
+ return music_info
165
+
166
+
167
+ class Paraphraser:
168
+ def __init__(self, paraphrase_source: tp.Union[str, Path], paraphrase_p: float = 0.):
169
+ self.paraphrase_p = paraphrase_p
170
+ open_fn = gzip.open if str(paraphrase_source).lower().endswith('.gz') else open
171
+ with open_fn(paraphrase_source, 'rb') as f: # type: ignore
172
+ self.paraphrase_source = json.loads(f.read())
173
+ logger.info(f"loaded paraphrasing source from: {paraphrase_source}")
174
+
175
+ def sample_paraphrase(self, audio_path: str, description: str):
176
+ if random.random() >= self.paraphrase_p:
177
+ return description
178
+ info_path = Path(audio_path).with_suffix('.json')
179
+ if info_path not in self.paraphrase_source:
180
+ warn_once(logger, f"{info_path} not in paraphrase source!")
181
+ return description
182
+ new_desc = random.choice(self.paraphrase_source[info_path])
183
+ logger.debug(f"{description} -> {new_desc}")
184
+ return new_desc
185
+
186
+
187
+ class MusicDataset(InfoAudioDataset):
188
+ """Music dataset is an AudioDataset with music-related metadata.
189
+
190
+ Args:
191
+ info_fields_required (bool): Whether to enforce having required fields.
192
+ merge_text_p (float): Probability of merging additional metadata to the description.
193
+ drop_desc_p (float): Probability of dropping the original description on text merge.
194
+ drop_other_p (float): Probability of dropping the other fields used for text augmentation.
195
+ joint_embed_attributes (list[str]): A list of attributes for which joint embedding metadata is returned.
196
+ paraphrase_source (str, optional): Path to the .json or .json.gz file containing the
197
+ paraphrases for the description. The json should be a dict whose keys are the
198
+ original info paths (e.g. track_path.json) and whose values are lists of possible
199
+ paraphrases.
200
+ paraphrase_p (float): probability of taking a paraphrase.
201
+
202
+ See `audiocraft.data.info_audio_dataset.InfoAudioDataset` for full initialization arguments.
203
+ """
204
+ def __init__(self, *args, info_fields_required: bool = True,
205
+ merge_text_p: float = 0., drop_desc_p: float = 0., drop_other_p: float = 0.,
206
+ joint_embed_attributes: tp.List[str] = [],
207
+ paraphrase_source: tp.Optional[str] = None, paraphrase_p: float = 0,
208
+ **kwargs):
209
+ kwargs['return_info'] = True # We require the info for each song of the dataset.
210
+ super().__init__(*args, **kwargs)
211
+ self.info_fields_required = info_fields_required
212
+ self.merge_text_p = merge_text_p
213
+ self.drop_desc_p = drop_desc_p
214
+ self.drop_other_p = drop_other_p
215
+ self.joint_embed_attributes = joint_embed_attributes
216
+ self.paraphraser = None
217
+ if paraphrase_source is not None:
218
+ self.paraphraser = Paraphraser(paraphrase_source, paraphrase_p)
219
+
220
+ def __getitem__(self, index):
221
+ wav, info = super().__getitem__(index)
222
+ info_data = info.to_dict()
223
+ music_info_path = Path(info.meta.path).with_suffix('.json')
224
+
225
+ if Path(music_info_path).exists():
226
+ with open(music_info_path, 'r') as json_file:
227
+ music_data = json.load(json_file)
228
+ music_data.update(info_data)
229
+ music_info = MusicInfo.from_dict(music_data, fields_required=self.info_fields_required)
230
+ if self.paraphraser is not None:
231
+ music_info.description = self.paraphraser.sample_paraphrase(music_info.meta.path, music_info.description)
232
+ if self.merge_text_p:
233
+ music_info = augment_music_info_description(
234
+ music_info, self.merge_text_p, self.drop_desc_p, self.drop_other_p)
235
+ else:
236
+ music_info = MusicInfo.from_dict(info_data, fields_required=False)
237
+
238
+ music_info.self_wav = WavCondition(
239
+ wav=wav[None], length=torch.tensor([info.n_frames]),
240
+ sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])
241
+
242
+ for att in self.joint_embed_attributes:
243
+ att_value = getattr(music_info, att)
244
+ joint_embed_cond = JointEmbedCondition(
245
+ wav[None], [att_value], torch.tensor([info.n_frames]),
246
+ sample_rate=[info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])
247
+ music_info.joint_embed[att] = joint_embed_cond
248
+
249
+ return wav, music_info
250
+
251
+
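A hedged usage sketch of `MusicDataset` (the manifest path is hypothetical; each audio file is assumed to have a sidecar `<track>.json` with the metadata fields listed above):

```python
from audiocraft.data.music_dataset import MusicDataset

dataset = MusicDataset.from_meta(
    "/path/to/manifest_dir",     # hypothetical dir containing data.jsonl(.gz)
    segment_duration=30.0,
    sample_rate=32_000,
    channels=1,
    info_fields_required=False,  # don't require every metadata field to be present
    merge_text_p=0.25,           # occasionally merge bpm/key/genre/moods into the description
    drop_desc_p=0.5,
    drop_other_p=0.5,
)
wav, music_info = dataset[0]
print(music_info.description, music_info.bpm, music_info.key)
```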
252
+ def get_musical_key(value: tp.Optional[str]) -> tp.Optional[str]:
253
+ """Preprocess the musical key, discarding the value if multiple keys are defined."""
254
+ if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
255
+ return None
256
+ elif ',' in value:
257
+ # For now, we discard the value when multiple keys are defined, separated with commas.
258
+ return None
259
+ else:
260
+ return value.strip().lower()
261
+
262
+
263
+ def get_bpm(value: tp.Optional[str]) -> tp.Optional[float]:
264
+ """Preprocess to a float."""
265
+ if value is None:
266
+ return None
267
+ try:
268
+ return float(value)
269
+ except ValueError:
270
+ return None
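As a quick illustration of the two preprocessors above:

```python
from audiocraft.data.music_dataset import get_bpm, get_musical_key

print(get_bpm("120"))                # 120.0
print(get_bpm("fast"))               # None (not parseable as a float)
print(get_musical_key(" C Minor "))  # 'c minor'
print(get_musical_key("C, G"))       # None (multiple keys are discarded)
```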
audiocraft/data/sound_dataset.py ADDED
@@ -0,0 +1,330 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Dataset of audio with a simple description.
7
+ """
8
+
9
+ from dataclasses import dataclass, fields, replace
10
+ import json
11
+ from pathlib import Path
12
+ import random
13
+ import typing as tp
14
+
15
+ import numpy as np
16
+ import torch
17
+
18
+ from .info_audio_dataset import (
19
+ InfoAudioDataset,
20
+ get_keyword_or_keyword_list
21
+ )
22
+ from ..modules.conditioners import (
23
+ ConditioningAttributes,
24
+ SegmentWithAttributes,
25
+ WavCondition,
26
+ )
27
+
28
+
29
+ EPS = torch.finfo(torch.float32).eps
30
+ TARGET_LEVEL_LOWER = -35
31
+ TARGET_LEVEL_UPPER = -15
32
+
33
+
34
+ @dataclass
35
+ class SoundInfo(SegmentWithAttributes):
36
+ """Segment info augmented with Sound metadata.
37
+ """
38
+ description: tp.Optional[str] = None
39
+ self_wav: tp.Optional[torch.Tensor] = None
40
+
41
+ @property
42
+ def has_sound_meta(self) -> bool:
43
+ return self.description is not None
44
+
45
+ def to_condition_attributes(self) -> ConditioningAttributes:
46
+ out = ConditioningAttributes()
47
+
48
+ for _field in fields(self):
49
+ key, value = _field.name, getattr(self, _field.name)
50
+ if key == 'self_wav':
51
+ out.wav[key] = value
52
+ else:
53
+ out.text[key] = value
54
+ return out
55
+
56
+ @staticmethod
57
+ def attribute_getter(attribute):
58
+ if attribute == 'description':
59
+ preprocess_func = get_keyword_or_keyword_list
60
+ else:
61
+ preprocess_func = None
62
+ return preprocess_func
63
+
64
+ @classmethod
65
+ def from_dict(cls, dictionary: dict, fields_required: bool = False):
66
+ _dictionary: tp.Dict[str, tp.Any] = {}
67
+
68
+ # allow a subset of attributes to not be loaded from the dictionary
69
+ # these attributes may be populated later
70
+ post_init_attributes = ['self_wav']
71
+
72
+ for _field in fields(cls):
73
+ if _field.name in post_init_attributes:
74
+ continue
75
+ elif _field.name not in dictionary:
76
+ if fields_required:
77
+ raise KeyError(f"Unexpected missing key: {_field.name}")
78
+ else:
79
+ preprocess_func: tp.Optional[tp.Callable] = cls.attribute_getter(_field.name)
80
+ value = dictionary[_field.name]
81
+ if preprocess_func:
82
+ value = preprocess_func(value)
83
+ _dictionary[_field.name] = value
84
+ return cls(**_dictionary)
85
+
86
+
87
+ class SoundDataset(InfoAudioDataset):
88
+ """Sound audio dataset: Audio dataset with environmental sound-specific metadata.
89
+
90
+ Args:
91
+ info_fields_required (bool): Whether all the mandatory metadata fields should be in the loaded metadata.
92
+ external_metadata_source (tp.Optional[str]): Folder containing JSON metadata for the corresponding dataset.
93
+ The metadata files contained in this folder are expected to match the stem of the audio file with
94
+ a json extension.
95
+ aug_p (float): Probability of performing audio mixing augmentation on the batch.
96
+ mix_p (float): Proportion of batch items that are mixed together when applying audio mixing augmentation.
97
+ mix_snr_low (int): Lowerbound for SNR value sampled for mixing augmentation.
98
+ mix_snr_high (int): Upperbound for SNR value sampled for mixing augmentation.
99
+ mix_min_overlap (float): Minimum overlap between audio files when performing mixing augmentation.
100
+ kwargs: Additional arguments for AudioDataset.
101
+
102
+ See `audiocraft.data.info_audio_dataset.InfoAudioDataset` for full initialization arguments.
103
+ """
104
+ def __init__(
105
+ self,
106
+ *args,
107
+ info_fields_required: bool = True,
108
+ external_metadata_source: tp.Optional[str] = None,
109
+ aug_p: float = 0.,
110
+ mix_p: float = 0.,
111
+ mix_snr_low: int = -5,
112
+ mix_snr_high: int = 5,
113
+ mix_min_overlap: float = 0.5,
114
+ **kwargs
115
+ ):
116
+ kwargs['return_info'] = True # We require the info for each song of the dataset.
117
+ super().__init__(*args, **kwargs)
118
+ self.info_fields_required = info_fields_required
119
+ self.external_metadata_source = external_metadata_source
120
+ self.aug_p = aug_p
121
+ self.mix_p = mix_p
122
+ if self.aug_p > 0:
123
+ assert self.mix_p > 0, "Expecting some mixing proportion mix_p if aug_p > 0"
124
+ assert self.channels == 1, "SoundDataset with audio mixing considers only monophonic audio"
125
+ self.mix_snr_low = mix_snr_low
126
+ self.mix_snr_high = mix_snr_high
127
+ self.mix_min_overlap = mix_min_overlap
128
+
129
+ def _get_info_path(self, path: tp.Union[str, Path]) -> Path:
130
+ """Get path of JSON with metadata (description, etc.).
131
+ If there exists a JSON with the same name as 'path.name', then it will be used.
132
+ Else, such JSON will be searched for in an external json source folder if it exists.
133
+ """
134
+ info_path = Path(path).with_suffix('.json')
135
+ if Path(info_path).exists():
136
+ return info_path
137
+ elif self.external_metadata_source and (Path(self.external_metadata_source) / info_path.name).exists():
138
+ return Path(self.external_metadata_source) / info_path.name
139
+ else:
140
+ raise Exception(f"Unable to find a metadata JSON for path: {path}")
141
+
142
+ def __getitem__(self, index):
143
+ wav, info = super().__getitem__(index)
144
+ info_data = info.to_dict()
145
+ info_path = self._get_info_path(info.meta.path)
146
+ if Path(info_path).exists():
147
+ with open(info_path, 'r') as json_file:
148
+ sound_data = json.load(json_file)
149
+ sound_data.update(info_data)
150
+ sound_info = SoundInfo.from_dict(sound_data, fields_required=self.info_fields_required)
151
+ # if there are multiple descriptions, sample one randomly
152
+ if isinstance(sound_info.description, list):
153
+ sound_info.description = random.choice(sound_info.description)
154
+ else:
155
+ sound_info = SoundInfo.from_dict(info_data, fields_required=False)
156
+
157
+ sound_info.self_wav = WavCondition(
158
+ wav=wav[None], length=torch.tensor([info.n_frames]),
159
+ sample_rate=[sound_info.sample_rate], path=[info.meta.path], seek_time=[info.seek_time])
160
+
161
+ return wav, sound_info
162
+
163
+ def collater(self, samples):
164
+ # when training, audio mixing is performed in the collate function
165
+ wav, sound_info = super().collater(samples) # SoundDataset always returns infos
166
+ if self.aug_p > 0:
167
+ wav, sound_info = mix_samples(wav, sound_info, self.aug_p, self.mix_p,
168
+ snr_low=self.mix_snr_low, snr_high=self.mix_snr_high,
169
+ min_overlap=self.mix_min_overlap)
170
+ return wav, sound_info
171
+
172
+
173
+ def rms_f(x: torch.Tensor) -> torch.Tensor:
174
+ return (x ** 2).mean(1).pow(0.5)
175
+
176
+
177
+ def normalize(audio: torch.Tensor, target_level: int = -25) -> torch.Tensor:
178
+ """Normalize the signal to the target level."""
179
+ rms = rms_f(audio)
180
+ scalar = 10 ** (target_level / 20) / (rms + EPS)
181
+ audio = audio * scalar.unsqueeze(1)
182
+ return audio
183
+
184
+
185
+ def is_clipped(audio: torch.Tensor, clipping_threshold: float = 0.99) -> torch.Tensor:
186
+ return (abs(audio) > clipping_threshold).any(1)
187
+
188
+
189
+ def mix_pair(src: torch.Tensor, dst: torch.Tensor, min_overlap: float) -> torch.Tensor:
190
+ start = random.randint(0, int(src.shape[1] * (1 - min_overlap)))
191
+ remainder = src.shape[1] - start
192
+ if dst.shape[1] > remainder:
193
+ src[:, start:] = src[:, start:] + dst[:, :remainder]
194
+ else:
195
+ src[:, start:start+dst.shape[1]] = src[:, start:start+dst.shape[1]] + dst
196
+ return src
197
+
198
+
199
+ def snr_mixer(clean: torch.Tensor, noise: torch.Tensor, snr: int, min_overlap: float,
200
+ target_level: int = -25, clipping_threshold: float = 0.99) -> torch.Tensor:
201
+ """Function to mix clean speech and noise at various SNR levels.
202
+
203
+ Args:
204
+ clean (torch.Tensor): Clean audio source to mix, of shape [B, T].
205
+ noise (torch.Tensor): Noise audio source to mix, of shape [B, T].
206
+ snr (int): SNR level when mixing.
207
+ min_overlap (float): Minimum overlap between the two mixed sources.
208
+ target_level (int): Gain level in dB.
209
+ clipping_threshold (float): Threshold for clipping the audio.
210
+ Returns:
211
+ torch.Tensor: The mixed audio, of shape [B, T].
212
+ """
213
+ if clean.shape[1] > noise.shape[1]:
214
+ noise = torch.nn.functional.pad(noise, (0, clean.shape[1] - noise.shape[1]))
215
+ else:
216
+ noise = noise[:, :clean.shape[1]]
217
+
218
+ # normalizing to -25 dB FS
219
+ clean = clean / (clean.max(1)[0].abs().unsqueeze(1) + EPS)
220
+ clean = normalize(clean, target_level)
221
+ rmsclean = rms_f(clean)
222
+
223
+ noise = noise / (noise.max(1)[0].abs().unsqueeze(1) + EPS)
224
+ noise = normalize(noise, target_level)
225
+ rmsnoise = rms_f(noise)
226
+
227
+ # set the noise level for a given SNR
228
+ noisescalar = (rmsclean / (10 ** (snr / 20)) / (rmsnoise + EPS)).unsqueeze(1)
229
+ noisenewlevel = noise * noisescalar
230
+
231
+ # mix noise and clean speech
232
+ noisyspeech = mix_pair(clean, noisenewlevel, min_overlap)
233
+
234
+ # randomly select RMS value between -15 dBFS and -35 dBFS and normalize noisyspeech with that value
235
+ # there is a small chance that clipping might happen, which is not a major issue.
236
+ noisy_rms_level = np.random.randint(TARGET_LEVEL_LOWER, TARGET_LEVEL_UPPER)
237
+ rmsnoisy = rms_f(noisyspeech)
238
+ scalarnoisy = (10 ** (noisy_rms_level / 20) / (rmsnoisy + EPS)).unsqueeze(1)
239
+ noisyspeech = noisyspeech * scalarnoisy
240
+ clean = clean * scalarnoisy
241
+ noisenewlevel = noisenewlevel * scalarnoisy
242
+
243
+ # final check to see if there are any amplitudes exceeding +/- 1. If so, normalize all the signals accordingly
244
+ clipped = is_clipped(noisyspeech)
245
+ if clipped.any():
246
+ noisyspeech_maxamplevel = noisyspeech[clipped].max(1)[0].abs().unsqueeze(1) / (clipping_threshold - EPS)
247
+ noisyspeech[clipped] = noisyspeech[clipped] / noisyspeech_maxamplevel
248
+
249
+ return noisyspeech
250
+
251
+
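A small sketch of the SNR mixing on random tensors (shapes follow the [B, T] convention from the docstring above):

```python
import torch

from audiocraft.data.sound_dataset import snr_mixer

clean = torch.randn(4, 16_000).clamp(-1, 1)  # fake batch of 1-second mono clips
noise = torch.randn(4, 8_000).clamp(-1, 1)   # shorter noise, padded internally

noisy = snr_mixer(clean, noise, snr=5, min_overlap=0.5)  # mix at 5 dB SNR
print(noisy.shape)  # torch.Size([4, 16000])
```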
252
+ def snr_mix(src: torch.Tensor, dst: torch.Tensor, snr_low: int, snr_high: int, min_overlap: float):
253
+ if snr_low == snr_high:
254
+ snr = snr_low
255
+ else:
256
+ snr = np.random.randint(snr_low, snr_high)
257
+ mix = snr_mixer(src, dst, snr, min_overlap)
258
+ return mix
259
+
260
+
261
+ def mix_text(src_text: str, dst_text: str):
262
+ """Mix text from different sources by concatenating them."""
263
+ if src_text == dst_text:
264
+ return src_text
265
+ return src_text + " " + dst_text
266
+
267
+
268
+ def mix_samples(wavs: torch.Tensor, infos: tp.List[SoundInfo], aug_p: float, mix_p: float,
269
+ snr_low: int, snr_high: int, min_overlap: float):
270
+ """Mix samples within a batch, summing the waveforms and concatenating the text infos.
271
+
272
+ Args:
273
+ wavs (torch.Tensor): Audio tensors of shape [B, C, T].
274
+ infos (list[SoundInfo]): List of SoundInfo items corresponding to the audio.
275
+ aug_p (float): Augmentation probability.
276
+ mix_p (float): Proportion of items in the batch to mix (and merge) together.
277
+ snr_low (int): Lowerbound for sampling SNR.
278
+ snr_high (int): Upperbound for sampling SNR.
279
+ min_overlap (float): Minimum overlap between mixed samples.
280
+ Returns:
281
+ tuple[torch.Tensor, list[SoundInfo]]: A tuple containing the mixed wavs
282
+ and mixed SoundInfo for the given batch.
283
+ """
284
+ # no mixing to perform within the batch
285
+ if mix_p == 0:
286
+ return wavs, infos
287
+
288
+ if random.uniform(0, 1) < aug_p:
289
+ # perform all augmentations on waveforms as [B, T]
290
+ # randomly picking pairs of audio to mix
291
+ assert wavs.size(1) == 1, f"Mix samples requires monophonic audio but C={wavs.size(1)}"
292
+ wavs = wavs.mean(dim=1, keepdim=False)
293
+ B, T = wavs.shape
294
+ k = int(mix_p * B)
295
+ mixed_sources_idx = torch.randperm(B)[:k]
296
+ mixed_targets_idx = torch.randperm(B)[:k]
297
+ aug_wavs = snr_mix(
298
+ wavs[mixed_sources_idx],
299
+ wavs[mixed_targets_idx],
300
+ snr_low,
301
+ snr_high,
302
+ min_overlap,
303
+ )
304
+ # mixing textual descriptions in metadata
305
+ descriptions = [info.description for info in infos]
306
+ aug_infos = []
307
+ for i, j in zip(mixed_sources_idx, mixed_targets_idx):
308
+ text = mix_text(descriptions[i], descriptions[j])
309
+ m = replace(infos[i])
310
+ m.description = text
311
+ aug_infos.append(m)
312
+
313
+ # back to [B, C, T]
314
+ aug_wavs = aug_wavs.unsqueeze(1)
315
+ assert aug_wavs.shape[0] > 0, "Samples mixing returned empty batch."
316
+ assert aug_wavs.dim() == 3, f"Returned wav should be [B, C, T] but dim = {aug_wavs.dim()}"
317
+ assert aug_wavs.shape[0] == len(aug_infos), "Mismatch between number of wavs and infos in the batch"
318
+
319
+ return aug_wavs, aug_infos # [B, C, T]
320
+ else:
321
+ # randomly pick samples in the batch to match
322
+ # the batch size when performing audio mixing
323
+ B, C, T = wavs.shape
324
+ k = int(mix_p * B)
325
+ wav_idx = torch.randperm(B)[:k]
326
+ wavs = wavs[wav_idx]
327
+ infos = [infos[i] for i in wav_idx]
328
+ assert wavs.shape[0] == len(infos), "Mismatch between number of wavs and infos in the batch"
329
+
330
+ return wavs, infos # [B, C, T]
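A hedged usage sketch of `SoundDataset` with mixing augmentation (the path is hypothetical; every audio file is assumed to have a sidecar JSON description):

```python
from torch.utils.data import DataLoader

from audiocraft.data.sound_dataset import SoundDataset

dataset = SoundDataset.from_path(
    "/path/to/sounds",   # hypothetical folder of mono environmental sounds
    segment_duration=10.0,
    sample_rate=16_000,
    channels=1,          # mixing augmentation requires mono audio
    aug_p=0.5,           # half of the batches get mixing augmentation
    mix_p=0.5,           # half of the items in an augmented batch are mixed
    mix_snr_low=-5, mix_snr_high=5, mix_min_overlap=0.5,
)
# Mixing happens in the collater, so it must be passed to the DataLoader.
loader = DataLoader(dataset, batch_size=8, collate_fn=dataset.collater)
wavs, infos = next(iter(loader))
```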
audiocraft/data/zip.py ADDED
@@ -0,0 +1,76 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Utility for reading some info from inside a zip file.
7
+ """
8
+
9
+ import typing
10
+ import zipfile
11
+
12
+ from dataclasses import dataclass
13
+ from functools import lru_cache
14
+ from typing_extensions import Literal
15
+
16
+
17
+ DEFAULT_SIZE = 32
18
+ MODE = Literal['r', 'w', 'x', 'a']
19
+
20
+
21
+ @dataclass(order=True)
22
+ class PathInZip:
23
+ """Holds the path of a file within a zip file.
24
+
25
+ Args:
26
+ path (str): The convention is <path_to_zip>:<relative_path_inside_zip>.
27
+ Let's assume there is a zip file /some/location/foo.zip
28
+ and inside of it is a json file located at /data/file1.json.
29
+ Then we expect path = "/some/location/foo.zip:/data/file1.json".
30
+ """
31
+
32
+ INFO_PATH_SEP = ':'
33
+ zip_path: str
34
+ file_path: str
35
+
36
+ def __init__(self, path: str) -> None:
37
+ split_path = path.split(self.INFO_PATH_SEP)
38
+ assert len(split_path) == 2
39
+ self.zip_path, self.file_path = split_path
40
+
41
+ @classmethod
42
+ def from_paths(cls, zip_path: str, file_path: str):
43
+ return cls(zip_path + cls.INFO_PATH_SEP + file_path)
44
+
45
+ def __str__(self) -> str:
46
+ return self.zip_path + self.INFO_PATH_SEP + self.file_path
47
+
48
+
49
+ def _open_zip(path: str, mode: MODE = 'r'):
50
+ return zipfile.ZipFile(path, mode)
51
+
52
+
53
+ _cached_open_zip = lru_cache(DEFAULT_SIZE)(_open_zip)
54
+
55
+
56
+ def set_zip_cache_size(max_size: int):
57
+ """Sets the maximal LRU caching for zip file opening.
58
+
59
+ Args:
60
+ max_size (int): The maximal LRU cache size.
61
+ """
62
+ global _cached_open_zip
63
+ _cached_open_zip = lru_cache(max_size)(_open_zip)
64
+
65
+
66
+ def open_file_in_zip(path_in_zip: PathInZip, mode: str = 'r') -> typing.IO:
67
+ """Opens a file stored inside a zip and returns a file-like object.
68
+
69
+ Args:
70
+ path_in_zip (PathInZip): A PathInZip object representing the file to return a file-like object of.
71
+ mode (str): The mode in which to open the file.
72
+ Returns:
73
+ A file-like object for PathInZip.
74
+ """
75
+ zf = _cached_open_zip(path_in_zip.zip_path)
76
+ return zf.open(path_in_zip.file_path)
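A short illustration of the path convention and the cached zip access (the zip file and member are made up):

```python
from audiocraft.data.zip import PathInZip, open_file_in_zip

path = PathInZip("/some/location/foo.zip:/data/file1.json")
print(path.zip_path)   # '/some/location/foo.zip'
print(path.file_path)  # '/data/file1.json'

with open_file_in_zip(path) as f:  # binary file-like object from the cached ZipFile
    data = f.read()
```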
audiocraft/environment.py ADDED
@@ -0,0 +1,176 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Provides cluster and tools configuration across clusters (slurm, dora, utilities).
9
+ """
10
+
11
+ import logging
12
+ import os
13
+ from pathlib import Path
14
+ import re
15
+ import typing as tp
16
+
17
+ import omegaconf
18
+
19
+ from .utils.cluster import _guess_cluster_type
20
+
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ class AudioCraftEnvironment:
26
+ """Environment configuration for teams and clusters.
27
+
28
+ AudioCraftEnvironment picks compute cluster settings (slurm, dora) from the current running environment
29
+ or declared variable and the loaded team configuration. Additionally, the AudioCraftEnvironment
30
+ provides pointers to a reference folder resolved automatically across clusters that is shared across team members,
31
+ allowing one to share sigs or other files to run jobs. Finally, it provides dataset mappers to automatically
32
+ map dataset file paths to new locations across clusters, allowing the same manifest of files to be used across clusters.
33
+
34
+ The cluster type is identified automatically and base configuration file is read from config/teams.yaml.
35
+ Use the following environment variables to specify the cluster, team or configuration:
36
+
37
+ AUDIOCRAFT_CLUSTER (optional): Cluster type to enforce. Useful if the cluster type
38
+ cannot be inferred automatically.
39
+ AUDIOCRAFT_CONFIG (optional): Path to yaml config holding the teams configuration.
40
+ If not set, configuration is read from config/teams.yaml.
41
+ AUDIOCRAFT_TEAM (optional): Name of the team. Recommended to set to your own team.
42
+ Cluster configurations are shared across teams to match compute allocation,
43
+ specify your cluster configuration in the configuration file under a key mapping
44
+ your team name.
45
+ """
46
+ _instance = None
47
+ DEFAULT_TEAM = "default"
48
+
49
+ def __init__(self) -> None:
50
+ """Loads configuration."""
51
+ self.team: str = os.getenv("AUDIOCRAFT_TEAM", self.DEFAULT_TEAM)
52
+ cluster_type = _guess_cluster_type()
53
+ cluster = os.getenv(
54
+ "AUDIOCRAFT_CLUSTER", cluster_type.value
55
+ )
56
+ logger.info("Detecting cluster type %s", cluster_type)
57
+
58
+ self.cluster: str = cluster
59
+
60
+ config_path = os.getenv(
61
+ "AUDIOCRAFT_CONFIG",
62
+ Path(__file__)
63
+ .parent.parent.joinpath("config/teams", self.team)
64
+ .with_suffix(".yaml"),
65
+ )
66
+ self.config = omegaconf.OmegaConf.load(config_path)
67
+ self._dataset_mappers = []
68
+ cluster_config = self._get_cluster_config()
69
+ if "dataset_mappers" in cluster_config:
70
+ for pattern, repl in cluster_config["dataset_mappers"].items():
71
+ regex = re.compile(pattern)
72
+ self._dataset_mappers.append((regex, repl))
73
+
74
+ def _get_cluster_config(self) -> omegaconf.DictConfig:
75
+ assert isinstance(self.config, omegaconf.DictConfig)
76
+ return self.config[self.cluster]
77
+
78
+ @classmethod
79
+ def instance(cls):
80
+ if cls._instance is None:
81
+ cls._instance = cls()
82
+ return cls._instance
83
+
84
+ @classmethod
85
+ def reset(cls):
86
+ """Clears the environment and forces a reload on next invocation."""
87
+ cls._instance = None
88
+
89
+ @classmethod
90
+ def get_team(cls) -> str:
91
+ """Gets the selected team as dictated by the AUDIOCRAFT_TEAM env var.
92
+ If not defined, defaults to "default".
93
+ """
94
+ return cls.instance().team
95
+
96
+ @classmethod
97
+ def get_cluster(cls) -> str:
98
+ """Gets the detected cluster.
99
+ This value can be overridden by the AUDIOCRAFT_CLUSTER env var.
100
+ """
101
+ return cls.instance().cluster
102
+
103
+ @classmethod
104
+ def get_dora_dir(cls) -> Path:
105
+ """Gets the path to the dora directory for the current team and cluster.
106
+ Value is overridden by the AUDIOCRAFT_DORA_DIR env var.
107
+ """
108
+ cluster_config = cls.instance()._get_cluster_config()
109
+ dora_dir = os.getenv("AUDIOCRAFT_DORA_DIR", cluster_config["dora_dir"])
110
+ logger.warning(f"Dora directory: {dora_dir}")
111
+ return Path(dora_dir)
112
+
113
+ @classmethod
114
+ def get_reference_dir(cls) -> Path:
115
+ """Gets the path to the reference directory for the current team and cluster.
116
+ Value is overridden by the AUDIOCRAFT_REFERENCE_DIR env var.
117
+ """
118
+ cluster_config = cls.instance()._get_cluster_config()
119
+ return Path(os.getenv("AUDIOCRAFT_REFERENCE_DIR", cluster_config["reference_dir"]))
120
+
121
+ @classmethod
122
+ def get_slurm_exclude(cls) -> tp.Optional[str]:
123
+ """Get the list of nodes to exclude for that cluster."""
124
+ cluster_config = cls.instance()._get_cluster_config()
125
+ return cluster_config.get("slurm_exclude")
126
+
127
+ @classmethod
128
+ def get_slurm_partitions(cls, partition_types: tp.Optional[tp.List[str]] = None) -> str:
129
+ """Gets the requested partitions for the current team and cluster as a comma-separated string.
130
+
131
+ Args:
132
+ partition_types (list[str], optional): partition types to retrieve. Values must be
133
+ from ['global', 'team']. If not provided, the global partition is returned.
134
+ """
135
+ if not partition_types:
136
+ partition_types = ["global"]
137
+
138
+ cluster_config = cls.instance()._get_cluster_config()
139
+ partitions = [
140
+ cluster_config["partitions"][partition_type]
141
+ for partition_type in partition_types
142
+ ]
143
+ return ",".join(partitions)
144
+
145
+ @classmethod
146
+ def resolve_reference_path(cls, path: tp.Union[str, Path]) -> Path:
147
+ """Converts reference placeholder in path with configured reference dir to resolve paths.
148
+
149
+ Args:
150
+ path (str or Path): Path to resolve.
151
+ Returns:
152
+ Path: Resolved path.
153
+ """
154
+ path = str(path)
155
+
156
+ if path.startswith("//reference"):
157
+ reference_dir = cls.get_reference_dir()
158
+ logger.warning(f"Reference directory: {reference_dir}")
159
+ assert (
160
+ reference_dir.exists() and reference_dir.is_dir()
161
+ ), f"Reference directory does not exist: {reference_dir}."
162
+ path = re.sub("^//reference", str(reference_dir), path)
163
+
164
+ return Path(path)
165
+
166
+ @classmethod
167
+ def apply_dataset_mappers(cls, path: str) -> str:
168
+ """Applies dataset mapping regex rules as defined in the configuration.
169
+ If no rules are defined, the path is returned as-is.
170
+ """
171
+ instance = cls.instance()
172
+
173
+ for pattern, repl in instance._dataset_mappers:
174
+ path = pattern.sub(repl, path)
175
+
176
+ return path
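The environment helper above is driven entirely by the environment variables documented in its docstring plus the teams YAML. A minimal usage sketch (not part of the diff), assuming config/teams/default.yaml contains an entry for an assumed cluster key "local" with dora_dir, reference_dir and partitions fields:

import os

# Illustrative values: the "default" team and a "local" cluster entry are assumptions,
# not values shipped with this commit.
os.environ["AUDIOCRAFT_TEAM"] = "default"
os.environ["AUDIOCRAFT_CLUSTER"] = "local"     # bypasses automatic cluster detection

from audiocraft.environment import AudioCraftEnvironment

print(AudioCraftEnvironment.get_cluster())      # -> "local"
print(AudioCraftEnvironment.get_dora_dir())     # -> dora_dir from the cluster entry
# Paths starting with //reference are remapped onto the configured reference_dir;
# the target directory must exist on disk or the assert inside will fail.
print(AudioCraftEnvironment.resolve_reference_path("//reference/checkpoints/model.th"))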
audiocraft/grids/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Dora Grids."""
audiocraft/grids/_base_explorers.py ADDED
@@ -0,0 +1,80 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from abc import ABC, abstractmethod
8
+ import time
9
+ import typing as tp
10
+ from dora import Explorer
11
+ import treetable as tt
12
+
13
+
14
+ def get_sheep_ping(sheep) -> tp.Optional[str]:
15
+ """Return the amount of time since the Sheep made some update
16
+ to its log. Returns a str using the relevant time unit."""
17
+ ping = None
18
+ if sheep.log is not None and sheep.log.exists():
19
+ delta = time.time() - sheep.log.stat().st_mtime
20
+ if delta > 3600 * 24:
21
+ ping = f'{delta / (3600 * 24):.1f}d'
22
+ elif delta > 3600:
23
+ ping = f'{delta / (3600):.1f}h'
24
+ elif delta > 60:
25
+ ping = f'{delta / 60:.1f}m'
26
+ else:
27
+ ping = f'{delta:.1f}s'
28
+ return ping
29
+
30
+
31
+ class BaseExplorer(ABC, Explorer):
32
+ """Base explorer for AudioCraft grids.
33
+
34
+ All task-specific explorers are expected to implement the `get_grid_metrics`
35
+ method to specify logic about metrics to display for a given task.
36
+
37
+ If additional stages are used, the child explorer must define how to handle
38
+ these new stages in the `process_history` and `process_sheep` methods.
39
+ """
40
+ def stages(self):
41
+ return ["train", "valid", "evaluate"]
42
+
43
+ def get_grid_meta(self):
44
+ """Returns the list of Meta information to display for each XP/job.
45
+ """
46
+ return [
47
+ tt.leaf("index", align=">"),
48
+ tt.leaf("name", wrap=140),
49
+ tt.leaf("state"),
50
+ tt.leaf("sig", align=">"),
51
+ tt.leaf("sid", align="<"),
52
+ ]
53
+
54
+ @abstractmethod
55
+ def get_grid_metrics(self):
56
+ """Return the metrics that should be displayed in the tracking table.
57
+ """
58
+ ...
59
+
60
+ def process_sheep(self, sheep, history):
61
+ train = {
62
+ "epoch": len(history),
63
+ }
64
+ parts = {"train": train}
65
+ for metrics in history:
66
+ for key, sub in metrics.items():
67
+ part = parts.get(key, {})
68
+ if 'duration' in sub:
69
+ # Convert to minutes for readability.
70
+ sub['duration'] = sub['duration'] / 60.
71
+ part.update(sub)
72
+ parts[key] = part
73
+ ping = get_sheep_ping(sheep)
74
+ if ping is not None:
75
+ for name in self.stages():
76
+ if name not in parts:
77
+ parts[name] = {}
78
+ # Add the ping to each part for convenience.
79
+ parts[name]['ping'] = ping
80
+ return parts
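Task-specific explorers built on the base class above only need to provide get_grid_metrics; the meta columns, ping computation and history handling come from BaseExplorer. A hedged sketch of such a subclass (the "loss" metric key is a placeholder, not a name guaranteed by any solver):

import treetable as tt

from audiocraft.grids._base_explorers import BaseExplorer


class MyTaskExplorer(BaseExplorer):
    """Minimal example explorer: one metric group per stage."""

    def get_grid_metrics(self):
        # 'loss' is a placeholder key; real solvers expose their own metric names.
        return [
            tt.group("train", [tt.leaf("epoch"), tt.leaf("loss", ".4f")], align=">"),
            tt.group("valid", [tt.leaf("loss", ".4f")], align=">"),
        ]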
audiocraft/grids/audiogen/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """AudioGen grids."""
audiocraft/grids/audiogen/audiogen_base_16khz.py ADDED
@@ -0,0 +1,23 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from ..musicgen._explorers import LMExplorer
8
+ from ...environment import AudioCraftEnvironment
9
+
10
+
11
+ @LMExplorer
12
+ def explorer(launcher):
13
+ partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
14
+ launcher.slurm_(gpus=64, partition=partitions)
15
+ launcher.bind_(solver='audiogen/audiogen_base_16khz')
16
+ # replace this by the desired environmental sound dataset
17
+ launcher.bind_(dset='internal/sounds_16khz')
18
+
19
+ fsdp = {'autocast': False, 'fsdp.use': True}
20
+ medium = {'model/lm/model_scale': 'medium'}
21
+
22
+ launcher.bind_(fsdp)
23
+ launcher(medium)
audiocraft/grids/audiogen/audiogen_pretrained_16khz_eval.py ADDED
@@ -0,0 +1,68 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Evaluation with objective metrics for the pretrained AudioGen models.
9
+ This grid takes the signature from the training grid and runs the evaluation-only stage.
10
+
11
+ When running the grid for the first time, please use:
12
+ REGEN=1 dora grid audiogen.audiogen_pretrained_16khz_eval
13
+ and re-use the REGEN=1 option when the grid is changed to force regenerating it.
14
+
15
+ Note that you need the proper metrics external libraries setup to use all
16
+ the objective metrics activated in this grid. Refer to the README for more information.
17
+ """
18
+
19
+ import os
20
+
21
+ from ..musicgen._explorers import GenerationEvalExplorer
22
+ from ...environment import AudioCraftEnvironment
23
+ from ... import train
24
+
25
+
26
+ def eval(launcher, batch_size: int = 32):
27
+ opts = {
28
+ 'dset': 'audio/audiocaps_16khz',
29
+ 'solver/audiogen/evaluation': 'objective_eval',
30
+ 'execute_only': 'evaluate',
31
+ '+dataset.evaluate.batch_size': batch_size,
32
+ '+metrics.fad.tf.batch_size': 32,
33
+ }
34
+ # binary for FAD computation: replace this path with your own path
35
+ metrics_opts = {
36
+ 'metrics.fad.tf.bin': '/data/home/jadecopet/local/usr/opt/google-research'
37
+ }
38
+ opt1 = {'generate.lm.use_sampling': True, 'generate.lm.top_k': 250, 'generate.lm.top_p': 0.}
39
+ opt2 = {'transformer_lm.two_step_cfg': True}
40
+
41
+ sub = launcher.bind(opts)
42
+ sub.bind_(metrics_opts)
43
+
44
+ # base objective metrics
45
+ sub(opt1, opt2)
46
+
47
+
48
+ @GenerationEvalExplorer
49
+ def explorer(launcher):
50
+ partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
51
+ launcher.slurm_(gpus=4, partition=partitions)
52
+
53
+ if 'REGEN' not in os.environ:
54
+ folder = train.main.dora.dir / 'grids' / __name__.split('.', 2)[-1]
55
+ with launcher.job_array():
56
+ for sig in folder.iterdir():
57
+ if not sig.is_symlink():
58
+ continue
59
+ xp = train.main.get_xp_from_sig(sig.name)
60
+ launcher(xp.argv)
61
+ return
62
+
63
+ audiogen_base = launcher.bind(solver="audiogen/audiogen_base_16khz")
64
+ audiogen_base.bind_({'autocast': False, 'fsdp.use': True})
65
+
66
+ audiogen_base_medium = audiogen_base.bind({'continue_from': '//pretrained/facebook/audiogen-medium'})
67
+ audiogen_base_medium.bind_({'model/lm/model_scale': 'medium'})
68
+ eval(audiogen_base_medium, batch_size=128)
audiocraft/grids/compression/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """EnCodec grids."""
audiocraft/grids/compression/_explorers.py ADDED
@@ -0,0 +1,55 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import treetable as tt
8
+
9
+ from .._base_explorers import BaseExplorer
10
+
11
+
12
+ class CompressionExplorer(BaseExplorer):
13
+ eval_metrics = ["sisnr", "visqol"]
14
+
15
+ def stages(self):
16
+ return ["train", "valid", "evaluate"]
17
+
18
+ def get_grid_meta(self):
19
+ """Returns the list of Meta information to display for each XP/job.
20
+ """
21
+ return [
22
+ tt.leaf("index", align=">"),
23
+ tt.leaf("name", wrap=140),
24
+ tt.leaf("state"),
25
+ tt.leaf("sig", align=">"),
26
+ ]
27
+
28
+ def get_grid_metrics(self):
29
+ """Return the metrics that should be displayed in the tracking table.
30
+ """
31
+ return [
32
+ tt.group(
33
+ "train",
34
+ [
35
+ tt.leaf("epoch"),
36
+ tt.leaf("bandwidth", ".2f"),
37
+ tt.leaf("adv", ".4f"),
38
+ tt.leaf("d_loss", ".4f"),
39
+ ],
40
+ align=">",
41
+ ),
42
+ tt.group(
43
+ "valid",
44
+ [
45
+ tt.leaf("bandwidth", ".2f"),
46
+ tt.leaf("adv", ".4f"),
47
+ tt.leaf("msspec", ".4f"),
48
+ tt.leaf("sisnr", ".2f"),
49
+ ],
50
+ align=">",
51
+ ),
52
+ tt.group(
53
+ "evaluate", [tt.leaf(name, ".3f") for name in self.eval_metrics], align=">"
54
+ ),
55
+ ]
audiocraft/grids/compression/debug.py ADDED
@@ -0,0 +1,31 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Grid search file, simply list all the exp you want in `explorer`.
9
+ Any new exp added there will be scheduled.
10
+ You can cancel an experiment by commenting its line.
11
+
12
+ This grid is a minimal example for debugging compression task
13
+ and how to override parameters directly in a grid.
14
+ Learn more about dora grids: https://github.com/facebookresearch/dora
15
+ """
16
+
17
+ from ._explorers import CompressionExplorer
18
+ from ...environment import AudioCraftEnvironment
19
+
20
+
21
+ @CompressionExplorer
22
+ def explorer(launcher):
23
+ partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
24
+ launcher.slurm_(gpus=2, partition=partitions)
25
+ launcher.bind_(solver='compression/debug')
26
+
27
+ with launcher.job_array():
28
+ # base debug task using config from solver=compression/debug
29
+ launcher()
30
+ # we can override parameters in the grid to launch additional xps
31
+ launcher({'rvq.bins': 2048, 'rvq.n_q': 4})
audiocraft/grids/compression/encodec_audiogen_16khz.py ADDED
@@ -0,0 +1,29 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Grid search file, simply list all the exp you want in `explorer`.
9
+ Any new exp added there will be scheduled.
10
+ You can cancel an experiment by commenting its line.
11
+
12
+ This grid shows how to train the new AudioGen EnCodec model at 16 kHz.
13
+ """
14
+
15
+ from ._explorers import CompressionExplorer
16
+ from ...environment import AudioCraftEnvironment
17
+
18
+
19
+ @CompressionExplorer
20
+ def explorer(launcher):
21
+ partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
22
+ launcher.slurm_(gpus=8, partition=partitions)
23
+ # use configuration for AudioGen's EnCodec model trained on monophonic audio sampled at 16 kHz
24
+ # AudioGen's EnCodec is trained with a total stride of 320 leading to a frame rate of 50 Hz
25
+ launcher.bind_(solver='compression/encodec_audiogen_16khz')
26
+ # replace this by the desired sound dataset
27
+ launcher.bind_(dset='internal/sounds_16khz')
28
+ # launch xp
29
+ launcher()
audiocraft/grids/compression/encodec_base_24khz.py ADDED
@@ -0,0 +1,28 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Grid search file, simply list all the exp you want in `explorer`.
9
+ Any new exp added there will be scheduled.
10
+ You can cancel an experiment by commenting its line.
11
+
12
+ This grid shows how to train a base causal EnCodec model at 24 kHz.
13
+ """
14
+
15
+ from ._explorers import CompressionExplorer
16
+ from ...environment import AudioCraftEnvironment
17
+
18
+
19
+ @CompressionExplorer
20
+ def explorer(launcher):
21
+ partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
22
+ launcher.slurm_(gpus=8, partition=partitions)
23
+ # base causal EnCodec trained on monophonic audio sampled at 24 kHz
24
+ launcher.bind_(solver='compression/encodec_base_24khz')
25
+ # replace this by the desired dataset
26
+ launcher.bind_(dset='audio/example')
27
+ # launch xp
28
+ launcher()
audiocraft/grids/compression/encodec_musicgen_32khz.py ADDED
@@ -0,0 +1,34 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Grid search file, simply list all the exp you want in `explorer`.
9
+ Any new exp added there will be scheduled.
10
+ You can cancel an experiment by commenting its line.
11
+
12
+ This grid shows how to train a MusicGen EnCodec model at 32 kHz.
13
+ """
14
+
15
+ from ._explorers import CompressionExplorer
16
+ from ...environment import AudioCraftEnvironment
17
+
18
+
19
+ @CompressionExplorer
20
+ def explorer(launcher):
21
+ partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
22
+ launcher.slurm_(gpus=8, partition=partitions)
23
+ # use configuration for MusicGen's EnCodec model trained on monophonic audio sampled at 32 kHz
24
+ # MusicGen's EnCodec is trained with a total stride of 640 leading to a frame rate of 50 Hz
25
+ launcher.bind_(solver='compression/encodec_musicgen_32khz')
26
+ # replace this by the desired music dataset
27
+ launcher.bind_(dset='internal/music_400k_32khz')
28
+ # launch xp
29
+ launcher()
30
+ launcher({
31
+ 'metrics.visqol.bin': '/data/home/jadecopet/local/usr/opt/visqol',
32
+ 'label': 'visqol',
33
+ 'evaluate.metrics.visqol': True
34
+ })
audiocraft/grids/diffusion/4_bands_base_32khz.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Training of the 4 diffusion models described in
9
+ "From Discrete Tokens to High-Fidelity Audio Using Multi-Band Diffusion"
10
+ (paper link).
11
+ """
12
+
13
+ from ._explorers import DiffusionExplorer
14
+
15
+
16
+ @DiffusionExplorer
17
+ def explorer(launcher):
18
+ launcher.slurm_(gpus=4, partition='learnfair')
19
+
20
+ launcher.bind_({'solver': 'diffusion/default',
21
+ 'dset': 'internal/music_10k_32khz'})
22
+
23
+ with launcher.job_array():
24
+ launcher({'filter.use': True, 'filter.idx_band': 0, "processor.use": False, 'processor.power_std': 0.4})
25
+ launcher({'filter.use': True, 'filter.idx_band': 1, "processor.use": False, 'processor.power_std': 0.4})
26
+ launcher({'filter.use': True, 'filter.idx_band': 2, "processor.use": True, 'processor.power_std': 0.4})
27
+ launcher({'filter.use': True, 'filter.idx_band': 3, "processor.use": True, 'processor.power_std': 0.75})
audiocraft/grids/diffusion/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Diffusion grids."""
audiocraft/grids/diffusion/_explorers.py ADDED
@@ -0,0 +1,66 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import treetable as tt
8
+
9
+ from .._base_explorers import BaseExplorer
10
+
11
+
12
+ class DiffusionExplorer(BaseExplorer):
13
+ eval_metrics = ["sisnr", "visqol"]
14
+
15
+ def stages(self):
16
+ return ["train", "valid", "valid_ema", "evaluate", "evaluate_ema"]
17
+
18
+ def get_grid_meta(self):
19
+ """Returns the list of Meta information to display for each XP/job.
20
+ """
21
+ return [
22
+ tt.leaf("index", align=">"),
23
+ tt.leaf("name", wrap=140),
24
+ tt.leaf("state"),
25
+ tt.leaf("sig", align=">"),
26
+ ]
27
+
28
+ def get_grid_metrics(self):
29
+ """Return the metrics that should be displayed in the tracking table.
30
+ """
31
+ return [
32
+ tt.group(
33
+ "train",
34
+ [
35
+ tt.leaf("epoch"),
36
+ tt.leaf("loss", ".3%"),
37
+ ],
38
+ align=">",
39
+ ),
40
+ tt.group(
41
+ "valid",
42
+ [
43
+ tt.leaf("loss", ".3%"),
44
+ # tt.leaf("loss_0", ".3%"),
45
+ ],
46
+ align=">",
47
+ ),
48
+ tt.group(
49
+ "valid_ema",
50
+ [
51
+ tt.leaf("loss", ".3%"),
52
+ # tt.leaf("loss_0", ".3%"),
53
+ ],
54
+ align=">",
55
+ ),
56
+ tt.group(
57
+ "evaluate", [tt.leaf("rvm", ".4f"), tt.leaf("rvm_0", ".4f"),
58
+ tt.leaf("rvm_1", ".4f"), tt.leaf("rvm_2", ".4f"),
59
+ tt.leaf("rvm_3", ".4f"), ], align=">"
60
+ ),
61
+ tt.group(
62
+ "evaluate_ema", [tt.leaf("rvm", ".4f"), tt.leaf("rvm_0", ".4f"),
63
+ tt.leaf("rvm_1", ".4f"), tt.leaf("rvm_2", ".4f"),
64
+ tt.leaf("rvm_3", ".4f")], align=">"
65
+ ),
66
+ ]
audiocraft/grids/musicgen/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """MusicGen grids."""
audiocraft/grids/musicgen/_explorers.py ADDED
@@ -0,0 +1,93 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import typing as tp
8
+
9
+ import treetable as tt
10
+
11
+ from .._base_explorers import BaseExplorer
12
+
13
+
14
+ class LMExplorer(BaseExplorer):
15
+ eval_metrics: tp.List[str] = []
16
+
17
+ def stages(self) -> tp.List[str]:
18
+ return ['train', 'valid']
19
+
20
+ def get_grid_metrics(self):
21
+ """Return the metrics that should be displayed in the tracking table."""
22
+ return [
23
+ tt.group(
24
+ 'train',
25
+ [
26
+ tt.leaf('epoch'),
27
+ tt.leaf('duration', '.1f'), # duration in minutes
28
+ tt.leaf('ping'),
29
+ tt.leaf('ce', '.4f'), # cross entropy
30
+ tt.leaf("ppl", '.3f'), # perplexity
31
+ ],
32
+ align='>',
33
+ ),
34
+ tt.group(
35
+ 'valid',
36
+ [
37
+ tt.leaf('ce', '.4f'),
38
+ tt.leaf('ppl', '.3f'),
39
+ tt.leaf('best_ppl', '.3f'),
40
+ ],
41
+ align='>',
42
+ ),
43
+ ]
44
+
45
+ def process_sheep(self, sheep, history):
46
+ parts = super().process_sheep(sheep, history)
47
+
48
+ track_by = {'ppl': 'lower'} # values should be in ['lower', 'higher']
49
+ best_metrics = {k: (1 if v == 'lower' else -1) * float('inf') for k, v in track_by.items()}
50
+
51
+ def comparator(mode, a, b):
52
+ return a < b if mode == 'lower' else a > b
53
+
54
+ for metrics in history:
55
+ for key, sub in metrics.items():
56
+ for metric in track_by:
57
+ # for the validation set, keep track of best metrics (ppl in this example)
58
+ # this is so we can conveniently compare metrics between runs in the grid
59
+ if key == 'valid' and metric in sub and comparator(
60
+ track_by[metric], sub[metric], best_metrics[metric]
61
+ ):
62
+ best_metrics[metric] = sub[metric]
63
+
64
+ if 'valid' in parts:
65
+ parts['valid'].update({f'best_{k}': v for k, v in best_metrics.items()})
66
+ return parts
67
+
68
+
69
+ class GenerationEvalExplorer(BaseExplorer):
70
+ eval_metrics: tp.List[str] = []
71
+
72
+ def stages(self) -> tp.List[str]:
73
+ return ['evaluate']
74
+
75
+ def get_grid_metrics(self):
76
+ """Return the metrics that should be displayed in the tracking table."""
77
+ return [
78
+ tt.group(
79
+ 'evaluate',
80
+ [
81
+ tt.leaf('epoch', '.3f'),
82
+ tt.leaf('duration', '.1f'),
83
+ tt.leaf('ping'),
84
+ tt.leaf('ce', '.4f'),
85
+ tt.leaf('ppl', '.3f'),
86
+ tt.leaf('fad', '.3f'),
87
+ tt.leaf('kld', '.3f'),
88
+ tt.leaf('text_consistency', '.3f'),
89
+ tt.leaf('chroma_cosine', '.3f'),
90
+ ],
91
+ align='>',
92
+ ),
93
+ ]
audiocraft/grids/musicgen/musicgen_base_32khz.py ADDED
@@ -0,0 +1,43 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from ._explorers import LMExplorer
8
+ from ...environment import AudioCraftEnvironment
9
+
10
+
11
+ @LMExplorer
12
+ def explorer(launcher):
13
+ partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
14
+ launcher.slurm_(gpus=32, partition=partitions)
15
+ launcher.bind_(solver='musicgen/musicgen_base_32khz')
16
+ # replace this by the desired music dataset
17
+ launcher.bind_(dset='internal/music_400k_32khz')
18
+
19
+ fsdp = {'autocast': False, 'fsdp.use': True}
20
+ medium = {'model/lm/model_scale': 'medium'}
21
+ large = {'model/lm/model_scale': 'large'}
22
+
23
+ cfg_low = {'classifier_free_guidance.training_dropout': 0.2}
24
+ wd_low = {'conditioners.description.t5.word_dropout': 0.2}
25
+
26
+ adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4}
27
+
28
+ launcher.bind_(fsdp)
29
+
30
+ launcher.slurm_(gpus=32).bind_(label='32gpus')
31
+ with launcher.job_array():
32
+ sub = launcher.bind()
33
+ sub()
34
+
35
+ launcher.slurm_(gpus=64).bind_(label='64gpus')
36
+ with launcher.job_array():
37
+ sub = launcher.bind()
38
+ sub(medium, adam)
39
+
40
+ launcher.slurm_(gpus=96).bind_(label='96gpus')
41
+ with launcher.job_array():
42
+ sub = launcher.bind()
43
+ sub(large, cfg_low, wd_low, adam, {'optim.max_norm': 3})
audiocraft/grids/musicgen/musicgen_base_cached_32khz.py ADDED
@@ -0,0 +1,67 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from ._explorers import LMExplorer
8
+ from ...environment import AudioCraftEnvironment
9
+
10
+
11
+ @LMExplorer
12
+ def explorer(launcher):
13
+ partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
14
+ launcher.slurm_(gpus=32, partition=partitions)
15
+ launcher.bind_(solver='musicgen/musicgen_base_32khz')
16
+ # replace this by the desired music dataset
17
+ launcher.bind_(dset='internal/music_400k_32khz')
18
+
19
+ fsdp = {'autocast': False, 'fsdp.use': True}
20
+ medium = {'model/lm/model_scale': 'medium'}
21
+ large = {'model/lm/model_scale': 'large'}
22
+
23
+ cfg_low = {'classifier_free_guidance.training_dropout': 0.2}
24
+ wd_low = {'conditioners.description.t5.word_dropout': 0.2}
25
+
26
+ adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4}
27
+
28
+ # BEGINNING OF CACHE WRITING JOBS.
29
+ cache_write = {
30
+ 'cache.path': '/fsx-codegen/defossez/cache/interleave_stereo_nv_32k',
31
+ 'cache.write': True,
32
+ 'generate.every': 500,
33
+ 'evaluate.every': 500,
34
+ 'logging.log_updates': 50,
35
+ }
36
+
37
+ cache_sub = launcher.bind({'model/lm/model_scale': 'xsmall', 'conditioner': 'none'})
38
+ cache_sub.bind_({'deadlock.use': True})
39
+ cache_sub.slurm_(gpus=8)
40
+ with launcher.job_array():
41
+ num_shards = 10 # total number of jobs running in parallel.
42
+ for shard in range(0, num_shards):
43
+ launcher(cache_write, {'cache.write_num_shards': num_shards, 'cache.write_shard': shard})
44
+
45
+ # REMOVE THE FOLLOWING RETURN STATEMENT ONCE THE ABOVE JOBS ARE DONE,
46
+ # OR SUFFICIENTLY AHEAD.
47
+ return
48
+
49
+ cache = {
50
+ 'cache.path': '/fsx-codegen/defossez/cache/interleave_stereo_nv_32k',
51
+ }
52
+ launcher.bind_(fsdp, cache)
53
+
54
+ launcher.slurm_(gpus=32).bind_(label='32gpus')
55
+ with launcher.job_array():
56
+ sub = launcher.bind()
57
+ sub()
58
+
59
+ launcher.slurm_(gpus=64).bind_(label='64gpus')
60
+ with launcher.job_array():
61
+ sub = launcher.bind()
62
+ sub(medium, adam)
63
+
64
+ launcher.slurm_(gpus=96).bind_(label='96gpus')
65
+ with launcher.job_array():
66
+ sub = launcher.bind()
67
+ sub(large, cfg_low, wd_low, adam, {'optim.max_norm': 3})
audiocraft/grids/musicgen/musicgen_clapemb_32khz.py ADDED
@@ -0,0 +1,32 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from ._explorers import LMExplorer
8
+ from ...environment import AudioCraftEnvironment
9
+
10
+
11
+ @LMExplorer
12
+ def explorer(launcher):
13
+ partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
14
+ launcher.slurm_(gpus=32, partition=partitions)
15
+ launcher.bind_(solver='musicgen/musicgen_base_32khz')
16
+ # replace this by the desired music dataset
17
+ launcher.bind_(dset='internal/music_400k_32khz')
18
+ launcher.bind_(conditioner='clapemb2music')
19
+
20
+ fsdp = {'autocast': False, 'fsdp.use': True}
21
+ cache_path = {'conditioners.description.clap.cache_path':
22
+ '/fsx-audio-craft-llm/jadecopet/experiments/audiocraft/caches/clap_embed_music'}
23
+ text_wav_training_opt = {'conditioners.description.clap.text_p': 0.5}
24
+
25
+ launcher.bind_(fsdp)
26
+
27
+ launcher.slurm_(gpus=32).bind_(label='32gpus')
28
+ with launcher.job_array():
29
+ launcher()
30
+ launcher(text_wav_training_opt)
31
+ launcher(cache_path)
32
+ launcher(cache_path, text_wav_training_opt)
audiocraft/grids/musicgen/musicgen_melody_32khz.py ADDED
@@ -0,0 +1,65 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from ._explorers import LMExplorer
8
+ from ...environment import AudioCraftEnvironment
9
+
10
+
11
+ @LMExplorer
12
+ def explorer(launcher):
13
+ partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
14
+ launcher.slurm_(gpus=32, partition=partitions)
15
+ launcher.bind_(solver='musicgen/musicgen_melody_32khz')
16
+ # replace this by the desired music dataset
17
+ launcher.bind_(dset='internal/music_400k_32khz')
18
+
19
+ fsdp = {'autocast': False, 'fsdp.use': True}
20
+ medium = {'model/lm/model_scale': 'medium'}
21
+ large = {'model/lm/model_scale': 'large'}
22
+
23
+ cfg_low = {'classifier_free_guidance.training_dropout': 0.2}
24
+ wd_low = {'conditioners.description.t5.word_dropout': 0.2}
25
+
26
+ adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4}
27
+
28
+ cache_path = {'conditioners.self_wav.chroma_stem.cache_path':
29
+ '/fsx-audio-craft-llm/jadecopet/experiments/audiocraft/caches/chroma_stem'}
30
+
31
+ # CACHE GENERATION JOBS
32
+ n_cache_gen_jobs = 4
33
+ gen_sub = launcher.slurm(gpus=1)
34
+ gen_sub.bind_(
35
+ cache_path, {
36
+ # the cache is always computed over the whole file, so duration doesn't matter here.
37
+ 'dataset.segment_duration': 2.,
38
+ 'dataset.batch_size': 8,
39
+ 'dataset.train.permutation_on_files': True, # try to not repeat files.
40
+ 'optim.epochs': 10,
41
+ 'model/lm/model_scale': 'xsmall',
42
+
43
+ })
44
+ with gen_sub.job_array():
45
+ for gen_job in range(n_cache_gen_jobs):
46
+ gen_sub({'dataset.train.shuffle_seed': gen_job})
47
+
48
+ # ACTUAL TRAINING JOBS.
49
+ launcher.bind_(fsdp)
50
+
51
+ launcher.slurm_(gpus=32).bind_(label='32gpus')
52
+ with launcher.job_array():
53
+ sub = launcher.bind()
54
+ sub()
55
+ sub(cache_path)
56
+
57
+ launcher.slurm_(gpus=64).bind_(label='64gpus')
58
+ with launcher.job_array():
59
+ sub = launcher.bind()
60
+ sub(medium, adam)
61
+
62
+ launcher.slurm_(gpus=96).bind_(label='96gpus')
63
+ with launcher.job_array():
64
+ sub = launcher.bind()
65
+ sub(large, cfg_low, wd_low, adam, {'optim.max_norm': 3})
audiocraft/grids/musicgen/musicgen_pretrained_32khz_eval.py ADDED
@@ -0,0 +1,99 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Evaluation with objective metrics for the pretrained MusicGen models.
9
+ This grid takes the signature from the training grid and runs the evaluation-only stage.
10
+
11
+ When running the grid for the first time, please use:
12
+ REGEN=1 dora grid musicgen.musicgen_pretrained_32khz_eval
13
+ and re-use the REGEN=1 option when the grid is changed to force regenerating it.
14
+
15
+ Note that you need the proper metrics external libraries setup to use all
16
+ the objective metrics activated in this grid. Refer to the README for more information.
17
+ """
18
+
19
+ import os
20
+
21
+ from ._explorers import GenerationEvalExplorer
22
+ from ...environment import AudioCraftEnvironment
23
+ from ... import train
24
+
25
+
26
+ def eval(launcher, batch_size: int = 32, eval_melody: bool = False):
27
+ opts = {
28
+ 'dset': 'audio/musiccaps_32khz',
29
+ 'solver/musicgen/evaluation': 'objective_eval',
30
+ 'execute_only': 'evaluate',
31
+ '+dataset.evaluate.batch_size': batch_size,
32
+ '+metrics.fad.tf.batch_size': 16,
33
+ }
34
+ # chroma-specific evaluation
35
+ chroma_opts = {
36
+ 'dset': 'internal/music_400k_32khz',
37
+ 'dataset.evaluate.segment_duration': 30,
38
+ 'dataset.evaluate.num_samples': 1000,
39
+ 'evaluate.metrics.chroma_cosine': True,
40
+ 'evaluate.metrics.fad': False,
41
+ 'evaluate.metrics.kld': False,
42
+ 'evaluate.metrics.text_consistency': False,
43
+ }
44
+ # binary for FAD computation: replace this path with your own path
45
+ metrics_opts = {
46
+ 'metrics.fad.tf.bin': '/data/home/jadecopet/local/usr/opt/google-research'
47
+ }
48
+ opt1 = {'generate.lm.use_sampling': True, 'generate.lm.top_k': 250, 'generate.lm.top_p': 0.}
49
+ opt2 = {'transformer_lm.two_step_cfg': True}
50
+
51
+ sub = launcher.bind(opts)
52
+ sub.bind_(metrics_opts)
53
+
54
+ # base objective metrics
55
+ sub(opt1, opt2)
56
+
57
+ if eval_melody:
58
+ # chroma-specific metrics
59
+ sub(opt1, opt2, chroma_opts)
60
+
61
+
62
+ @GenerationEvalExplorer
63
+ def explorer(launcher):
64
+ partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
65
+ launcher.slurm_(gpus=4, partition=partitions)
66
+
67
+ if 'REGEN' not in os.environ:
68
+ folder = train.main.dora.dir / 'grids' / __name__.split('.', 2)[-1]
69
+ with launcher.job_array():
70
+ for sig in folder.iterdir():
71
+ if not sig.is_symlink():
72
+ continue
73
+ xp = train.main.get_xp_from_sig(sig.name)
74
+ launcher(xp.argv)
75
+ return
76
+
77
+ with launcher.job_array():
78
+ musicgen_base = launcher.bind(solver="musicgen/musicgen_base_32khz")
79
+ musicgen_base.bind_({'autocast': False, 'fsdp.use': True})
80
+
81
+ # base musicgen models
82
+ musicgen_base_small = musicgen_base.bind({'continue_from': '//pretrained/facebook/musicgen-small'})
83
+ eval(musicgen_base_small, batch_size=128)
84
+
85
+ musicgen_base_medium = musicgen_base.bind({'continue_from': '//pretrained/facebook/musicgen-medium'})
86
+ musicgen_base_medium.bind_({'model/lm/model_scale': 'medium'})
87
+ eval(musicgen_base_medium, batch_size=128)
88
+
89
+ musicgen_base_large = musicgen_base.bind({'continue_from': '//pretrained/facebook/musicgen-large'})
90
+ musicgen_base_large.bind_({'model/lm/model_scale': 'large'})
91
+ eval(musicgen_base_large, batch_size=128)
92
+
93
+ # melody musicgen model
94
+ musicgen_melody = launcher.bind(solver="musicgen/musicgen_melody_32khz")
95
+ musicgen_melody.bind_({'autocast': False, 'fsdp.use': True})
96
+
97
+ musicgen_melody_medium = musicgen_melody.bind({'continue_from': '//pretrained/facebook/musicgen-melody'})
98
+ musicgen_melody_medium.bind_({'model/lm/model_scale': 'medium'})
99
+ eval(musicgen_melody_medium, batch_size=128, eval_melody=True)
audiocraft/losses/__init__.py ADDED
@@ -0,0 +1,21 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Loss related classes and functions. In particular the loss balancer from
7
+ EnCodec, and the usual spectral losses."""
8
+
9
+ # flake8: noqa
10
+ from .balancer import Balancer
11
+ from .sisnr import SISNR
12
+ from .stftloss import (
13
+ LogSTFTMagnitudeLoss,
14
+ MRSTFTLoss,
15
+ SpectralConvergenceLoss,
16
+ STFTLoss
17
+ )
18
+ from .specloss import (
19
+ MelSpectrogramL1Loss,
20
+ MultiScaleMelSpectrogramLoss,
21
+ )
audiocraft/losses/balancer.py ADDED
@@ -0,0 +1,136 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import typing as tp
8
+
9
+ import flashy
10
+ import torch
11
+ from torch import autograd
12
+
13
+
14
+ class Balancer:
15
+ """Loss balancer.
16
+
17
+ The loss balancer combines losses together to compute gradients for the backward.
18
+ Given `y = f(...)`, and a number of losses `l1(y, ...)`, `l2(y, ...)`, with `...`
19
+ not having any dependence on `f`, the balancer can efficiently normalize the partial gradients
20
+ `d l1 / d y`, `d l2 / dy` before summing them in order to achieve a desired ratio between
21
+ the losses. For instance if `weights = {'l1': 2, 'l2': 1}`, 66% of the gradient
22
+ going into `f(...)` will come from `l1` on average, and 33% from `l2`. This allows for an easy
23
+ interpretation of the weights even if the intrinsic scale of `l1`, `l2` ... is unknown.
24
+
25
+ Noting `g1 = d l1 / dy`, etc., the balanced gradient `G` will be
26
+ (with `avg` an exponential moving average over the updates),
27
+
28
+ G = sum_i total_norm * g_i / avg(||g_i||) * w_i / sum(w_i)
29
+
30
+ If `balance_grads` is False, this is deactivated, and instead the gradient will just be the
31
+ standard sum of the partial gradients with the given weights.
32
+
33
+ A call to the backward method of the balancer will compute the partial gradients,
34
+ combining all the losses and potentially rescaling the gradients,
35
+ which can help stabilize the training and reason about multiple losses with varying scales.
36
+ The obtained gradient with respect to `y` is then back-propagated to `f(...)`.
37
+
38
+ Expected usage:
39
+
40
+ weights = {'loss_a': 1, 'loss_b': 4}
41
+ balancer = Balancer(weights, ...)
42
+ losses: dict = {}
43
+ losses['loss_a'] = compute_loss_a(x, y)
44
+ losses['loss_b'] = compute_loss_b(x, y)
45
+ if model.training():
46
+ effective_loss = balancer.backward(losses, x)
47
+
48
+ Args:
49
+ weights (dict[str, float]): Weight coefficient for each loss. The balancer expect the losses keys
50
+ from the backward method to match the weights keys to assign weight to each of the provided loss.
51
+ balance_grads (bool): Whether to rescale gradients so that weights reflect the fraction of the
52
+ overall gradient, rather than a constant multiplier.
53
+ total_norm (float): Reference norm when rescaling gradients, ignored otherwise.
54
+ ema_decay (float): EMA decay for averaging the norms.
55
+ per_batch_item (bool): Whether to compute the averaged norm per batch item or not. This only holds
56
+ when rescaling the gradients.
57
+ epsilon (float): Epsilon value for numerical stability.
58
+ monitor (bool): If True, stores in `self.metrics` the relative ratio between the norm of the gradients
59
+ coming from each loss, when calling `backward()`.
60
+ """
61
+ def __init__(self, weights: tp.Dict[str, float], balance_grads: bool = True, total_norm: float = 1.,
62
+ ema_decay: float = 0.999, per_batch_item: bool = True, epsilon: float = 1e-12,
63
+ monitor: bool = False):
64
+ self.weights = weights
65
+ self.per_batch_item = per_batch_item
66
+ self.total_norm = total_norm or 1.
67
+ self.averager = flashy.averager(ema_decay or 1.)
68
+ self.epsilon = epsilon
69
+ self.monitor = monitor
70
+ self.balance_grads = balance_grads
71
+ self._metrics: tp.Dict[str, tp.Any] = {}
72
+
73
+ @property
74
+ def metrics(self):
75
+ return self._metrics
76
+
77
+ def backward(self, losses: tp.Dict[str, torch.Tensor], input: torch.Tensor) -> torch.Tensor:
78
+ """Compute the backward and return the effective train loss, e.g. the loss obtained from
79
+ computing the effective weights. If `balance_grads` is True, the effective weights
80
+ are the one that needs to be applied to each gradient to respect the desired relative
81
+ scale of gradients coming from each loss.
82
+
83
+ Args:
84
+ losses (Dict[str, torch.Tensor]): dictionary with the same keys as `self.weights`.
85
+ input (torch.Tensor): the input of the losses, typically the output of the model.
86
+ This should be the single point of dependence between the losses
87
+ and the model being trained.
88
+ """
89
+ norms = {}
90
+ grads = {}
91
+ for name, loss in losses.items():
92
+ # Compute the partial derivative of the loss with respect to the input.
93
+ grad, = autograd.grad(loss, [input], retain_graph=True)
94
+ if self.per_batch_item:
95
+ # We do not average the gradient over the batch dimension.
96
+ dims = tuple(range(1, grad.dim()))
97
+ norm = grad.norm(dim=dims, p=2).mean()
98
+ else:
99
+ norm = grad.norm(p=2)
100
+ norms[name] = norm
101
+ grads[name] = grad
102
+
103
+ count = 1
104
+ if self.per_batch_item:
105
+ count = len(grad)
106
+ # Average norms across workers. Theoretically we should average the
107
+ # squared norm, then take the sqrt, but it worked fine like that.
108
+ avg_norms = flashy.distrib.average_metrics(self.averager(norms), count)
109
+ # We approximate the total norm of the gradient as the sums of the norms.
110
+ # Obviously this can be very incorrect if all gradients are aligned, but it works fine.
111
+ total = sum(avg_norms.values())
112
+
113
+ self._metrics = {}
114
+ if self.monitor:
115
+ # Store the ratio of the total gradient represented by each loss.
116
+ for k, v in avg_norms.items():
117
+ self._metrics[f'ratio_{k}'] = v / total
118
+
119
+ total_weights = sum([self.weights[k] for k in avg_norms])
120
+ assert total_weights > 0.
121
+ desired_ratios = {k: w / total_weights for k, w in self.weights.items()}
122
+
123
+ out_grad = torch.zeros_like(input)
124
+ effective_loss = torch.tensor(0., device=input.device, dtype=input.dtype)
125
+ for name, avg_norm in avg_norms.items():
126
+ if self.balance_grads:
127
+ # g_balanced = g / avg(||g||) * total_norm * desired_ratio
128
+ scale = desired_ratios[name] * self.total_norm / (self.epsilon + avg_norm)
129
+ else:
130
+ # We just do regular weighted sum of the gradients.
131
+ scale = self.weights[name]
132
+ out_grad.add_(grads[name], alpha=scale)
133
+ effective_loss += scale * losses[name].detach()
134
+ # Send the computed partial derivative with respect to the output of the model to the model.
135
+ input.backward(out_grad)
136
+ return effective_loss
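To make the "Expected usage" above concrete, here is a self-contained toy run of the balancer (illustrative only: the signals and the two losses are made up, and a plain single-process run is assumed):

import torch
import torch.nn.functional as F

from audiocraft.losses.balancer import Balancer

x = torch.randn(4, 1, 256, requires_grad=True)
y = 2.0 * x                          # stand-in for the model output f(x)
target = torch.zeros_like(y)

losses = {
    'l1': F.l1_loss(y, target),
    'l2': F.mse_loss(y, target),
}
# With these weights, roughly 2/3 of the gradient reaching x should come from 'l1'.
balancer = Balancer(weights={'l1': 2.0, 'l2': 1.0}, monitor=True)
effective_loss = balancer.backward(losses, y)   # back-propagates the balanced gradient into x
print(effective_loss.item(), balancer.metrics, x.grad.shape)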
audiocraft/losses/sisnr.py ADDED
@@ -0,0 +1,92 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ import typing as tp
9
+
10
+ import torch
11
+ from torch import nn
12
+ from torch.nn import functional as F
13
+
14
+
15
+ def _unfold(a: torch.Tensor, kernel_size: int, stride: int) -> torch.Tensor:
16
+ """Given input of size [*OT, T], output Tensor of size [*OT, F, K]
17
+ with K the kernel size, by extracting frames with the given stride.
18
+ This will pad the input so that `F = ceil(T / stride)`.
19
+ see https://github.com/pytorch/pytorch/issues/60466
20
+ """
21
+ *shape, length = a.shape
22
+ n_frames = math.ceil(length / stride)
23
+ tgt_length = (n_frames - 1) * stride + kernel_size
24
+ a = F.pad(a, (0, tgt_length - length))
25
+ strides = list(a.stride())
26
+ assert strides[-1] == 1, "data should be contiguous"
27
+ strides = strides[:-1] + [stride, 1]
28
+ return a.as_strided([*shape, n_frames, kernel_size], strides)
29
+
30
+
31
+ def _center(x: torch.Tensor) -> torch.Tensor:
32
+ return x - x.mean(-1, True)
33
+
34
+
35
+ def _norm2(x: torch.Tensor) -> torch.Tensor:
36
+ return x.pow(2).sum(-1, True)
37
+
38
+
39
+ class SISNR(nn.Module):
40
+ """SISNR loss.
41
+
42
+ Input should be [B, C, T], output is scalar.
43
+
44
+ Args:
45
+ sample_rate (int): Sample rate.
46
+ segment (float or None): Evaluate on chunks of that many seconds. If None, evaluate on
47
+ entire audio only.
48
+ overlap (float): Overlap between chunks, i.e. 0.5 = 50 % overlap.
49
+ epsilon (float): Epsilon value for numerical stability.
50
+ """
51
+ def __init__(
52
+ self,
53
+ sample_rate: int = 16000,
54
+ segment: tp.Optional[float] = 20,
55
+ overlap: float = 0.5,
56
+ epsilon: float = torch.finfo(torch.float32).eps,
57
+ ):
58
+ super().__init__()
59
+ self.sample_rate = sample_rate
60
+ self.segment = segment
61
+ self.overlap = overlap
62
+ self.epsilon = epsilon
63
+
64
+ def forward(self, out_sig: torch.Tensor, ref_sig: torch.Tensor) -> torch.Tensor:
65
+ B, C, T = ref_sig.shape
66
+ assert ref_sig.shape == out_sig.shape
67
+
68
+ if self.segment is None:
69
+ frame = T
70
+ stride = T
71
+ else:
72
+ frame = int(self.segment * self.sample_rate)
73
+ stride = int(frame * (1 - self.overlap))
74
+
75
+ epsilon = self.epsilon * frame # make epsilon prop to frame size.
76
+
77
+ gt = _unfold(ref_sig, frame, stride)
78
+ est = _unfold(out_sig, frame, stride)
79
+ if self.segment is None:
80
+ assert gt.shape[-1] == 1
81
+
82
+ gt = _center(gt)
83
+ est = _center(est)
84
+ dot = torch.einsum("bcft,bcft->bcf", gt, est)
85
+
86
+ proj = dot[:, :, :, None] * gt / (epsilon + _norm2(gt))
87
+ noise = est - proj
88
+
89
+ sisnr = 10 * (
90
+ torch.log10(epsilon + _norm2(proj)) - torch.log10(epsilon + _norm2(noise))
91
+ )
92
+ return -1 * sisnr[..., 0].mean()
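A quick sanity-check sketch for the SISNR module; the batch shape, sample rate and noise level below are arbitrary:

import torch

from audiocraft.losses.sisnr import SISNR

sisnr_loss = SISNR(sample_rate=16000, segment=5.0, overlap=0.5)
ref = torch.randn(2, 1, 16000 * 10)        # [B, C, T]: 10 seconds of reference audio
est = ref + 0.05 * torch.randn_like(ref)   # noisy estimate of the reference
loss = sisnr_loss(est, ref)                # scalar; minimizing it maximizes SI-SNR
print(loss.item())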
audiocraft/losses/specloss.py ADDED
@@ -0,0 +1,149 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import typing as tp
8
+
9
+ import numpy as np
10
+ from torchaudio.transforms import MelSpectrogram
11
+ import torch
12
+ from torch import nn
13
+ from torch.nn import functional as F
14
+
15
+ from ..modules import pad_for_conv1d
16
+
17
+
18
+ class MelSpectrogramWrapper(nn.Module):
19
+ """Wrapper around MelSpectrogram torchaudio transform providing proper padding
20
+ and additional post-processing including log scaling.
21
+
22
+ Args:
23
24
+ n_fft (int): Number of fft.
25
+ hop_length (int): Hop size.
26
+ win_length (int): Window length.
27
+ n_mels (int): Number of mel bins.
28
+ sample_rate (int): Sample rate.
29
+ f_min (float or None): Minimum frequency.
30
+ f_max (float or None): Maximum frequency.
31
+ log (bool): Whether to scale with log.
32
+ normalized (bool): Whether to normalize the melspectrogram.
33
+ floor_level (float): Floor level based on human perception (default=1e-5).
34
+ """
35
+ def __init__(self, n_fft: int = 1024, hop_length: int = 256, win_length: tp.Optional[int] = None,
36
+ n_mels: int = 80, sample_rate: float = 22050, f_min: float = 0.0, f_max: tp.Optional[float] = None,
37
+ log: bool = True, normalized: bool = False, floor_level: float = 1e-5):
38
+ super().__init__()
39
+ self.n_fft = n_fft
40
+ hop_length = int(hop_length)
41
+ self.hop_length = hop_length
42
+ self.mel_transform = MelSpectrogram(n_mels=n_mels, sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length,
43
+ win_length=win_length, f_min=f_min, f_max=f_max, normalized=normalized,
44
+ window_fn=torch.hann_window, center=False)
45
+ self.floor_level = floor_level
46
+ self.log = log
47
+
48
+ def forward(self, x):
49
+ p = int((self.n_fft - self.hop_length) // 2)
50
+ if len(x.shape) == 2:
51
+ x = x.unsqueeze(1)
52
+ x = F.pad(x, (p, p), "reflect")
53
+ # Make sure that all the frames are full.
54
+ # The combination of `pad_for_conv1d` and the above padding
55
+ # will make the output of size ceil(T / hop).
56
+ x = pad_for_conv1d(x, self.n_fft, self.hop_length)
57
+ self.mel_transform.to(x.device)
58
+ mel_spec = self.mel_transform(x)
59
+ B, C, freqs, frame = mel_spec.shape
60
+ if self.log:
61
+ mel_spec = torch.log10(self.floor_level + mel_spec)
62
+ return mel_spec.reshape(B, C * freqs, frame)
63
+
64
+
65
+ class MelSpectrogramL1Loss(torch.nn.Module):
66
+ """L1 Loss on MelSpectrogram.
67
+
68
+ Args:
69
+ sample_rate (int): Sample rate.
70
+ n_fft (int): Number of fft.
71
+ hop_length (int): Hop size.
72
+ win_length (int): Window length.
73
+ n_mels (int): Number of mel bins.
74
+ f_min (float or None): Minimum frequency.
75
+ f_max (float or None): Maximum frequency.
76
+ log (bool): Whether to scale with log.
77
+ normalized (bool): Whether to normalize the melspectrogram.
78
+ floor_level (float): Floor level value based on human perception (default=1e-5).
79
+ """
80
+ def __init__(self, sample_rate: int, n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024,
81
+ n_mels: int = 80, f_min: float = 0.0, f_max: tp.Optional[float] = None,
82
+ log: bool = True, normalized: bool = False, floor_level: float = 1e-5):
83
+ super().__init__()
84
+ self.l1 = torch.nn.L1Loss()
85
+ self.melspec = MelSpectrogramWrapper(n_fft=n_fft, hop_length=hop_length, win_length=win_length,
86
+ n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max,
87
+ log=log, normalized=normalized, floor_level=floor_level)
88
+
89
+ def forward(self, x, y):
90
+ self.melspec.to(x.device)
91
+ s_x = self.melspec(x)
92
+ s_y = self.melspec(y)
93
+ return self.l1(s_x, s_y)
94
+
95
+
96
+ class MultiScaleMelSpectrogramLoss(nn.Module):
97
+ """Multi-Scale spectrogram loss (msspec).
98
+
99
+ Args:
100
+ sample_rate (int): Sample rate.
101
+ range_start (int): Power of 2 to use for the first scale.
102
+ range_end (int): Power of 2 to use for the last scale.
103
+ n_mels (int): Number of mel bins.
104
+ f_min (float): Minimum frequency.
105
+ f_max (float or None): Maximum frequency.
106
+ normalized (bool): Whether to normalize the melspectrogram.
107
+ alphas (bool): Whether to use alphas as coefficients or not.
108
+ floor_level (float): Floor level value based on human perception (default=1e-5).
109
+ """
110
+ def __init__(self, sample_rate: int, range_start: int = 6, range_end: int = 11,
111
+ n_mels: int = 64, f_min: float = 0.0, f_max: tp.Optional[float] = None,
112
+ normalized: bool = False, alphas: bool = True, floor_level: float = 1e-5):
113
+ super().__init__()
114
+ l1s = list()
115
+ l2s = list()
116
+ self.alphas = list()
117
+ self.total = 0
118
+ self.normalized = normalized
119
+ for i in range(range_start, range_end):
120
+ l1s.append(
121
+ MelSpectrogramWrapper(n_fft=2 ** i, hop_length=(2 ** i) / 4, win_length=2 ** i,
122
+ n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max,
123
+ log=False, normalized=normalized, floor_level=floor_level))
124
+ l2s.append(
125
+ MelSpectrogramWrapper(n_fft=2 ** i, hop_length=(2 ** i) / 4, win_length=2 ** i,
126
+ n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max,
127
+ log=True, normalized=normalized, floor_level=floor_level))
128
+ if alphas:
129
+ self.alphas.append(np.sqrt(2 ** i - 1))
130
+ else:
131
+ self.alphas.append(1)
132
+ self.total += self.alphas[-1] + 1
133
+
134
+ self.l1s = nn.ModuleList(l1s)
135
+ self.l2s = nn.ModuleList(l2s)
136
+
137
+ def forward(self, x, y):
138
+ loss = 0.0
139
+ self.l1s.to(x.device)
140
+ self.l2s.to(x.device)
141
+ for i in range(len(self.alphas)):
142
+ s_x_1 = self.l1s[i](x)
143
+ s_y_1 = self.l1s[i](y)
144
+ s_x_2 = self.l2s[i](x)
145
+ s_y_2 = self.l2s[i](y)
146
+ loss += F.l1_loss(s_x_1, s_y_1) + self.alphas[i] * F.mse_loss(s_x_2, s_y_2)
147
+ if self.normalized:
148
+ loss = loss / self.total
149
+ return loss
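A minimal usage sketch for the multi-scale mel loss; the 32 kHz sample rate and the one-second random batch are arbitrary choices for illustration:

import torch

from audiocraft.losses.specloss import MultiScaleMelSpectrogramLoss

msspec = MultiScaleMelSpectrogramLoss(sample_rate=32000)
x = torch.randn(2, 1, 32000)   # predicted audio, [B, C, T]
y = torch.randn(2, 1, 32000)   # reference audio
loss = msspec(x, y)            # per scale: L1 on linear mels + weighted L2 on log mels
print(loss.item())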
audiocraft/losses/stftloss.py ADDED
@@ -0,0 +1,207 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # Adapted from MIT code under the original license
7
+ # Copyright 2019 Tomoki Hayashi
8
+ # MIT License (https://opensource.org/licenses/MIT)
9
+ import typing as tp
10
+
11
+ import torch
12
+ from torch import nn
13
+ from torch.nn import functional as F
14
+
15
+
16
+ # TODO: Replace with torchaudio.STFT?
17
+ def _stft(x: torch.Tensor, fft_size: int, hop_length: int, win_length: int,
18
+ window: tp.Optional[torch.Tensor], normalized: bool) -> torch.Tensor:
19
+ """Perform STFT and convert to magnitude spectrogram.
20
+
21
+ Args:
22
+ x: Input signal tensor (B, C, T).
23
+ fft_size (int): FFT size.
24
+ hop_length (int): Hop size.
25
+ win_length (int): Window length.
26
+ window (torch.Tensor or None): Window function type.
27
+ normalized (bool): Whether to normalize the STFT or not.
28
+
29
+ Returns:
30
+ torch.Tensor: Magnitude spectrogram (B, C, #frames, fft_size // 2 + 1).
31
+ """
32
+ B, C, T = x.shape
33
+ x_stft = torch.stft(
34
+ x.view(-1, T), fft_size, hop_length, win_length, window,
35
+ normalized=normalized, return_complex=True,
36
+ )
37
+ x_stft = x_stft.view(B, C, *x_stft.shape[1:])
38
+ real = x_stft.real
39
+ imag = x_stft.imag
40
+
41
+ # NOTE(kan-bayashi): clamp is needed to avoid nan or inf
42
+ return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1)
43
+
44
+
45
+ class SpectralConvergenceLoss(nn.Module):
46
+ """Spectral convergence loss.
47
+ """
48
+ def __init__(self, epsilon: float = torch.finfo(torch.float32).eps):
49
+ super().__init__()
50
+ self.epsilon = epsilon
51
+
52
+ def forward(self, x_mag: torch.Tensor, y_mag: torch.Tensor):
53
+ """Calculate forward propagation.
54
+
55
+ Args:
56
+ x_mag: Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
57
+ y_mag: Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
58
+ Returns:
59
+ torch.Tensor: Spectral convergence loss value.
60
+ """
61
+ return torch.norm(y_mag - x_mag, p="fro") / (torch.norm(y_mag, p="fro") + self.epsilon)
62
+
63
+
64
+ class LogSTFTMagnitudeLoss(nn.Module):
65
+ """Log STFT magnitude loss.
66
+
67
+ Args:
68
+ epsilon (float): Epsilon value for numerical stability.
69
+ """
70
+ def __init__(self, epsilon: float = torch.finfo(torch.float32).eps):
71
+ super().__init__()
72
+ self.epsilon = epsilon
73
+
74
+ def forward(self, x_mag: torch.Tensor, y_mag: torch.Tensor):
75
+ """Calculate forward propagation.
76
+
77
+ Args:
78
+ x_mag (torch.Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
79
+ y_mag (torch.Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
80
+ Returns:
81
+ torch.Tensor: Log STFT magnitude loss value.
82
+ """
83
+ return F.l1_loss(torch.log(self.epsilon + y_mag), torch.log(self.epsilon + x_mag))
84
+
85
+
86
+ class STFTLosses(nn.Module):
87
+ """STFT losses.
88
+
89
+ Args:
90
+ n_fft (int): Size of FFT.
91
+ hop_length (int): Hop length.
92
+ win_length (int): Window length.
93
+ window (str): Window function type.
94
+ normalized (bool): Whether to use normalized STFT or not.
95
+ epsilon (float): Epsilon for numerical stability.
96
+ """
97
+ def __init__(self, n_fft: int = 1024, hop_length: int = 120, win_length: int = 600,
98
+ window: str = "hann_window", normalized: bool = False,
99
+ epsilon: float = torch.finfo(torch.float32).eps):
100
+ super().__init__()
101
+ self.n_fft = n_fft
102
+ self.hop_length = hop_length
103
+ self.win_length = win_length
104
+ self.normalized = normalized
105
+ self.register_buffer("window", getattr(torch, window)(win_length))
106
+ self.spectral_convergence_loss = SpectralConvergenceLoss(epsilon)
107
+ self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss(epsilon)
108
+
109
+ def forward(self, x: torch.Tensor, y: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
110
+ """Calculate forward propagation.
111
+
112
+ Args:
113
+ x (torch.Tensor): Predicted signal (B, T).
114
+ y (torch.Tensor): Groundtruth signal (B, T).
115
+ Returns:
116
+ torch.Tensor: Spectral convergence loss value.
117
+ torch.Tensor: Log STFT magnitude loss value.
118
+ """
119
+ x_mag = _stft(x, self.n_fft, self.hop_length,
120
+ self.win_length, self.window, self.normalized) # type: ignore
121
+ y_mag = _stft(y, self.n_fft, self.hop_length,
122
+ self.win_length, self.window, self.normalized) # type: ignore
123
+ sc_loss = self.spectral_convergence_loss(x_mag, y_mag)
124
+ mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
125
+
126
+ return sc_loss, mag_loss
127
+
128
+
129
+ class STFTLoss(nn.Module):
130
+ """Single Resolution STFT loss.
131
+
132
+ Args:
133
+ n_fft (int): Size of FFT.
134
+ hop_length (int): Hop length.
135
+ win_length (int): Window length.
136
+ window (str): Window function type.
137
+ normalized (bool): Whether to use normalized STFT or not.
138
+ epsilon (float): Epsilon for numerical stability.
139
+ factor_sc (float): Coefficient for the spectral loss.
140
+ factor_mag (float): Coefficient for the magnitude loss.
141
+ """
142
+ def __init__(self, n_fft: int = 1024, hop_length: int = 120, win_length: int = 600,
143
+ window: str = "hann_window", normalized: bool = False,
144
+ factor_sc: float = 0.1, factor_mag: float = 0.1,
145
+ epsilon: float = torch.finfo(torch.float32).eps):
146
+ super().__init__()
147
+ self.loss = STFTLosses(n_fft, hop_length, win_length, window, normalized, epsilon)
148
+ self.factor_sc = factor_sc
149
+ self.factor_mag = factor_mag
150
+
151
+ def forward(self, x: torch.Tensor, y: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
152
+ """Calculate forward propagation.
153
+
154
+ Args:
155
+ x (torch.Tensor): Predicted signal (B, C, T).
156
+ y (torch.Tensor): Groundtruth signal (B, C, T).
157
+ Returns:
158
+ torch.Tensor: Single resolution STFT loss.
159
+ """
160
+ sc_loss, mag_loss = self.loss(x, y)
161
+ return self.factor_sc * sc_loss + self.factor_mag * mag_loss
162
+
163
+
164
+ class MRSTFTLoss(nn.Module):
165
+ """Multi resolution STFT loss.
166
+
167
+ Args:
168
+ n_ffts (Sequence[int]): Sequence of FFT sizes.
169
+ hop_lengths (Sequence[int]): Sequence of hop sizes.
170
+ win_lengths (Sequence[int]): Sequence of window lengths.
171
+ window (str): Window function type.
172
+ factor_sc (float): Coefficient for the spectral loss.
173
+ factor_mag (float): Coefficient for the magnitude loss.
174
+ normalized (bool): Whether to use normalized STFT or not.
175
+ epsilon (float): Epsilon for numerical stability.
176
+ """
177
+ def __init__(self, n_ffts: tp.Sequence[int] = [1024, 2048, 512], hop_lengths: tp.Sequence[int] = [120, 240, 50],
178
+ win_lengths: tp.Sequence[int] = [600, 1200, 240], window: str = "hann_window",
179
+ factor_sc: float = 0.1, factor_mag: float = 0.1,
180
+ normalized: bool = False, epsilon: float = torch.finfo(torch.float32).eps):
181
+ super().__init__()
182
+ assert len(n_ffts) == len(hop_lengths) == len(win_lengths)
183
+ self.stft_losses = torch.nn.ModuleList()
184
+ for fs, ss, wl in zip(n_ffts, hop_lengths, win_lengths):
185
+ self.stft_losses += [STFTLosses(fs, ss, wl, window, normalized, epsilon)]
186
+ self.factor_sc = factor_sc
187
+ self.factor_mag = factor_mag
188
+
189
+ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
190
+ """Calculate forward propagation.
191
+
192
+ Args:
193
+ x (torch.Tensor): Predicted signal (B, C, T).
194
+ y (torch.Tensor): Groundtruth signal (B, C, T).
195
+ Returns:
196
+ torch.Tensor: Multi resolution STFT loss.
197
+ """
198
+ sc_loss = torch.Tensor([0.0])
199
+ mag_loss = torch.Tensor([0.0])
200
+ for f in self.stft_losses:
201
+ sc_l, mag_l = f(x, y)
202
+ sc_loss += sc_l
203
+ mag_loss += mag_l
204
+ sc_loss /= len(self.stft_losses)
205
+ mag_loss /= len(self.stft_losses)
206
+
207
+ return self.factor_sc * sc_loss + self.factor_mag * mag_loss
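A minimal, hypothetical usage sketch for the multi-resolution STFT loss defined above; the waveform shapes and values are illustrative assumptions, and the import path simply follows the file location `audiocraft/losses/stftloss.py`.

```python
# Hypothetical usage sketch for MRSTFTLoss; tensors below are random placeholders.
import torch

from audiocraft.losses.stftloss import MRSTFTLoss

mrstft = MRSTFTLoss()             # defaults: FFT sizes (1024, 2048, 512) with matching hop/window lengths
x = torch.randn(4, 1, 24000)      # predicted waveforms, shape (B, C, T)
y = torch.randn(4, 1, 24000)      # reference waveforms, shape (B, C, T)
loss = mrstft(x, y)               # 0.1 * spectral convergence + 0.1 * log-magnitude L1, averaged over resolutions
print(loss)                       # single-element tensor; differentiable, so it can serve as a training objective
```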
audiocraft/metrics/__init__.py ADDED
@@ -0,0 +1,14 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Metrics like CLAP score, FAD, KLD, Visqol, Chroma similarity, etc.
7
+ """
8
+ # flake8: noqa
9
+ from .clap_consistency import CLAPTextConsistencyMetric, TextConsistencyMetric
10
+ from .chroma_cosinesim import ChromaCosineSimilarityMetric
11
+ from .fad import FrechetAudioDistanceMetric
12
+ from .kld import KLDivergenceMetric, PasstKLDivergenceMetric
13
+ from .rvm import RelativeVolumeMel
14
+ from .visqol import ViSQOL
audiocraft/metrics/chroma_cosinesim.py ADDED
@@ -0,0 +1,72 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torchmetrics
9
+
10
+ from ..data.audio_utils import convert_audio
11
+ from ..modules.chroma import ChromaExtractor
12
+
13
+
14
+ class ChromaCosineSimilarityMetric(torchmetrics.Metric):
15
+ """Chroma cosine similarity metric.
16
+
17
+ This metric extracts a chromagram for a reference waveform and
18
+ a generated waveform and compares each frame using the cosine similarity
19
+ function. The output is the mean cosine similarity.
20
+
21
+ Args:
22
+ sample_rate (int): Sample rate used by the chroma extractor.
23
+ n_chroma (int): Number of chroma used by the chroma extractor.
24
+ radix2_exp (int): Exponent for the chroma extractor.
25
+ argmax (bool): Whether the chroma extractor uses argmax.
26
+ eps (float): Epsilon for cosine similarity computation.
27
+ """
28
+ def __init__(self, sample_rate: int, n_chroma: int, radix2_exp: int, argmax: bool, eps: float = 1e-8):
29
+ super().__init__()
30
+ self.chroma_sample_rate = sample_rate
31
+ self.n_chroma = n_chroma
32
+ self.eps = eps
33
+ self.chroma_extractor = ChromaExtractor(sample_rate=self.chroma_sample_rate, n_chroma=self.n_chroma,
34
+ radix2_exp=radix2_exp, argmax=argmax)
35
+ self.add_state("cosine_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
36
+ self.add_state("weight", default=torch.tensor(0.), dist_reduce_fx="sum")
37
+
38
+ def update(self, preds: torch.Tensor, targets: torch.Tensor,
39
+ sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
40
+ """Compute cosine similarity between chromagrams and accumulate scores over the dataset."""
41
+ if preds.size(0) == 0:
42
+ return
43
+
44
+ assert preds.shape == targets.shape, (
45
+ f"Preds and target shapes mismatch: preds={preds.shape}, targets={targets.shape}")
46
+ assert preds.size(0) == sizes.size(0), (
47
+ f"Number of items in preds ({preds.shape}) mismatch ",
48
+ f"with sizes ({sizes.shape})")
49
+ assert preds.size(0) == sample_rates.size(0), (
50
+ f"Number of items in preds ({preds.shape}) mismatch ",
51
+ f"with sample_rates ({sample_rates.shape})")
52
+ assert torch.all(sample_rates == sample_rates[0].item()), "All sample rates in the batch should be the same"
53
+
54
+ device = self.weight.device
55
+ preds, targets = preds.to(device), targets.to(device) # type: ignore
56
+ sample_rate = sample_rates[0].item()
57
+ preds = convert_audio(preds, from_rate=sample_rate, to_rate=self.chroma_sample_rate, to_channels=1)
58
+ targets = convert_audio(targets, from_rate=sample_rate, to_rate=self.chroma_sample_rate, to_channels=1)
59
+ gt_chroma = self.chroma_extractor(targets)
60
+ gen_chroma = self.chroma_extractor(preds)
61
+ chroma_lens = (sizes / self.chroma_extractor.winhop).ceil().int()
62
+ for i in range(len(gt_chroma)):
63
+ t = int(chroma_lens[i].item())
64
+ cosine_sim = torch.nn.functional.cosine_similarity(
65
+ gt_chroma[i, :t], gen_chroma[i, :t], dim=1, eps=self.eps)
66
+ self.cosine_sum += cosine_sim.sum(dim=0) # type: ignore
67
+ self.weight += torch.tensor(t) # type: ignore
68
+
69
+ def compute(self) -> float:
70
+ """Computes the average cosine similarty across all generated/target chromagrams pairs."""
71
+ assert self.weight.item() > 0, "Unable to compute with total number of comparisons <= 0" # type: ignore
72
+ return (self.cosine_sum / self.weight).item() # type: ignore
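A hedged usage sketch for the chroma cosine similarity metric above; the chroma extractor arguments and tensor shapes are illustrative assumptions rather than values taken from this diff.

```python
# Hypothetical usage sketch for ChromaCosineSimilarityMetric; all tensors are random placeholders.
import torch

from audiocraft.metrics import ChromaCosineSimilarityMetric

metric = ChromaCosineSimilarityMetric(sample_rate=32000, n_chroma=12, radix2_exp=12, argmax=True)
preds = torch.randn(2, 1, 32000)             # generated audio, shape (B, C, T)
targets = torch.randn(2, 1, 32000)           # reference audio, same shape
sizes = torch.tensor([32000, 32000])         # valid length (in samples) of each item
sample_rates = torch.tensor([32000, 32000])  # per-item sample rates (must all be identical)
metric.update(preds, targets, sizes, sample_rates)
print(metric.compute())                      # mean framewise cosine similarity over the dataset
```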
audiocraft/metrics/clap_consistency.py ADDED
@@ -0,0 +1,84 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from pathlib import Path
8
+ import typing as tp
9
+
10
+ import torch
11
+ import torchmetrics
12
+ from transformers import RobertaTokenizer # type: ignore
13
+
14
+ from ..data.audio_utils import convert_audio
15
+ from ..environment import AudioCraftEnvironment
16
+ from ..utils.utils import load_clap_state_dict
17
+
18
+ try:
19
+ import laion_clap # type: ignore
20
+ except ImportError:
21
+ laion_clap = None
22
+
23
+
24
+ class TextConsistencyMetric(torchmetrics.Metric):
25
+ """Text consistency metric measuring consistency between audio and text pairs."""
26
+
27
+ def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
28
+ raise NotImplementedError("implement how to update the metric from the audio and text pairs.")
29
+
30
+ def compute(self):
31
+ raise NotImplementedError("implement how to compute the final metric score.")
32
+
33
+
34
+ class CLAPTextConsistencyMetric(TextConsistencyMetric):
35
+ """Text consistency metric relying on Contrastive Language-Audio Pretraining (CLAP).
36
+
37
+ This metric is similar to the MuLan Cycle Consistency from MusicLM (https://arxiv.org/pdf/2301.11325.pdf)
38
+ or the CLAP score used in Make-An-Audio (https://arxiv.org/pdf/2301.12661v1.pdf).
39
+
40
+ As a joint audio-text embedding model, a pretrained CLAP model can be used to quantify the
41
+ similarity between audio-text pairs. We compute the CLAP embeddings from the text descriptions as
42
+ well as the generated audio based on them, and define the MCC metric as the average cosine similarity
43
+ between these embeddings.
44
+
45
+ Model implementation & pre-trained checkpoints: https://github.com/LAION-AI/CLAP
46
+ """
47
+ def __init__(self, model_path: tp.Union[str, Path], model_arch: str = 'HTSAT-tiny', enable_fusion: bool = False):
48
+ super().__init__()
49
+ if laion_clap is None:
50
+ raise ImportError("Please install CLAP to compute text consistency: 'pip install laion_clap'")
51
+ self.add_state("cosine_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
52
+ self.add_state("weight", default=torch.tensor(0.), dist_reduce_fx="sum")
53
+ self._initialize_model(model_path, model_arch, enable_fusion)
54
+
55
+ def _initialize_model(self, model_path: tp.Union[str, Path], model_arch: str, enable_fusion: bool):
56
+ model_path = AudioCraftEnvironment.resolve_reference_path(model_path)
57
+ self.tokenize = RobertaTokenizer.from_pretrained('roberta-base')
58
+ self.model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch)
59
+ self.model_sample_rate = 48_000
60
+ load_clap_state_dict(self.model, model_path)
61
+ self.model.eval()
62
+
63
+ def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict:
64
+ # we use the default params from CLAP module here as well
65
+ return self.tokenize(texts, padding="max_length", truncation=True, max_length=77, return_tensors="pt")
66
+
67
+ def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
68
+ """Compute cosine similarity between audio and text pairs and accumulate scores over the dataset."""
69
+ assert audio.size(0) == len(text), "Number of audio and text samples should match"
70
+ assert torch.all(sample_rates == sample_rates[0].item()), "All items in batch should have the same sample rate"
71
+ sample_rate = int(sample_rates[0].item())
72
+ # convert audio batch to 48kHz monophonic audio with no channel dimension: [B, C, T] -> [B, T]
73
+ audio = convert_audio(audio, from_rate=sample_rate, to_rate=self.model_sample_rate, to_channels=1).mean(dim=1)
74
+ audio_embeddings = self.model.get_audio_embedding_from_data(audio, use_tensor=True)
75
+ text_embeddings = self.model.get_text_embedding(text, tokenizer=self._tokenizer, use_tensor=True)
76
+ # cosine similarity between the text and the audio embedding
77
+ cosine_sim = torch.nn.functional.cosine_similarity(audio_embeddings, text_embeddings, dim=1, eps=1e-8)
78
+ self.cosine_sum += cosine_sim.sum(dim=0)
79
+ self.weight += torch.tensor(cosine_sim.size(0))
80
+
81
+ def compute(self):
82
+ """Computes the average cosine similarty across all audio/text pairs."""
83
+ assert self.weight.item() > 0, "Unable to compute with total number of comparisons <= 0" # type: ignore
84
+ return (self.cosine_sum / self.weight).item() # type: ignore
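A hedged usage sketch for the CLAP text-consistency metric above; it assumes `laion_clap` is installed, and the checkpoint path is a placeholder for a locally downloaded CLAP model.

```python
# Hypothetical usage sketch for CLAPTextConsistencyMetric; the checkpoint path is a placeholder.
import torch

from audiocraft.metrics import CLAPTextConsistencyMetric

metric = CLAPTextConsistencyMetric(model_path="/path/to/clap_checkpoint.pt", model_arch="HTSAT-tiny")
audio = torch.randn(2, 1, 48000)                        # generated audio, shape (B, C, T)
text = ["calm acoustic guitar", "fast techno beat"]     # prompts the audio was generated from
sizes = torch.tensor([48000, 48000])
sample_rates = torch.tensor([48000, 48000])
metric.update(audio, text, sizes, sample_rates)         # accumulates audio/text cosine similarities
print(metric.compute())                                 # average CLAP score over the dataset
```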
audiocraft/metrics/fad.py ADDED
@@ -0,0 +1,329 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+ from pathlib import Path
9
+ import os
10
+ import subprocess
11
+ import tempfile
12
+ import typing as tp
13
+
14
+ from audiocraft.data.audio import audio_write
15
+ from audiocraft.data.audio_utils import convert_audio
16
+ import flashy
17
+ import torch
18
+ import torchmetrics
19
+
20
+ from ..environment import AudioCraftEnvironment
21
+
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+ VGGISH_SAMPLE_RATE = 16_000
26
+ VGGISH_CHANNELS = 1
27
+
28
+
29
+ class FrechetAudioDistanceMetric(torchmetrics.Metric):
30
+ """Fréchet Audio Distance computation based on official TensorFlow implementation from Google Research.
31
+
32
+ From: D.C. Dowson & B.V. Landau The Fréchet distance between
33
+ multivariate normal distributions
34
+ https://doi.org/10.1016/0047-259X(82)90077-X
35
+ The Fréchet distance between two multivariate gaussians,
36
+ `X ~ N(mu_x, sigma_x)` and `Y ~ N(mu_y, sigma_y)`, is `d^2`.
37
+ d^2 = (mu_x - mu_y)^2 + Tr(sigma_x + sigma_y - 2 * sqrt(sigma_x*sigma_y))
38
+ = (mu_x - mu_y)^2 + Tr(sigma_x) + Tr(sigma_y)
39
+ - 2 * Tr(sqrt(sigma_x*sigma_y)))
40
+
41
+ To use this FAD computation metric, you need to have the proper Frechet Audio Distance tool setup
42
+ from: https://github.com/google-research/google-research/tree/master/frechet_audio_distance
43
+ We provide the below instructions as reference but we do not guarantee for further support
44
+ in frechet_audio_distance installation. This was tested with python 3.10, cuda 11.8, tensorflow 2.12.0.
45
+
46
+ We recommend installing the frechet_audio_distance library in a dedicated env (e.g. conda).
47
+
48
+ 1. Get the code and models following the repository instructions. We used the steps below:
49
+ git clone git@github.com:google-research/google-research.git
50
+ git clone git@github.com:tensorflow/models.git
51
+ mkdir google-research/tensorflow_models
52
+ touch google-research/tensorflow_models/__init__.py
53
+ cp -r models/research/audioset google-research/tensorflow_models/
54
+ touch google-research/tensorflow_models/audioset/__init__.py
55
+ echo "from .vggish import mel_features, vggish_params, vggish_slim" > \
56
+ google-research/tensorflow_models/audioset/__init__.py
57
+ # we can now remove the tensorflow models repository
58
+ # rm -r models
59
+ cd google-research
60
+ Follow the instructions to download the vggish checkpoint. AudioCraft base configuration
61
+ assumes it is placed in the AudioCraft reference dir.
62
+
63
+ Note that we operate the following changes for the code to work with TensorFlow 2.X and python 3:
64
+ - Update xrange for range in:
65
+ https://github.com/google-research/google-research/blob/master/frechet_audio_distance/audioset_model.py
66
+ - Update `tf_record = tf.python_io.tf_record_iterator(filename).next()` to
67
+ `tf_record = tf.python_io.tf_record_iterator(filename).__next__()` in
68
+ https://github.com/google-research/google-research/blob/master/frechet_audio_distance/fad_utils.py
69
+ - Update `import vggish_params as params` to `from . import vggish_params as params` in:
70
+ https://github.com/tensorflow/models/blob/master/research/audioset/vggish/vggish_slim.py
71
+ - Add flag to provide a given batch size for running the AudioSet model in:
72
+ https://github.com/google-research/google-research/blob/master/frechet_audio_distance/create_embeddings_main.py
73
+ ```
74
+ flags.DEFINE_integer('batch_size', 64,
75
+ 'Number of samples in the batch for AudioSet model.')
76
+ ```
77
+ Ensure you pass the flag to the create_embeddings_beam.create_pipeline function, adding:
78
+ `batch_size=FLAGS.batch_size` to the provided parameters.
79
+
80
+ 2. Follow instructions for the library installation and a valid TensorFlow installation
81
+ ```
82
+ # e.g. instructions from: https://www.tensorflow.org/install/pip
83
+ conda install -c conda-forge cudatoolkit=11.8.0
84
+ python3 -m pip install nvidia-cudnn-cu11==8.6.0.163 tensorflow==2.12.*
85
+ mkdir -p $CONDA_PREFIX/etc/conda/activate.d
86
+ echo 'CUDNN_PATH=$(dirname $(python -c "import nvidia.cudnn;print(nvidia.cudnn.__file__)"))' \
87
+ >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
88
+ echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CONDA_PREFIX/lib/:$CUDNN_PATH/lib' \
89
+ >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
90
+ source $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
91
+ # Verify install: on a machine with GPU device
92
+ python3 -c "import tensorflow as tf; print(tf.config.list_physical_devices('GPU'))"
93
+ ```
94
+
95
+ Now install frechet_audio_distance required dependencies:
96
+ ```
97
+ # We assume we already have TensorFlow installed from the above steps
98
+ pip install apache-beam numpy scipy tf_slim
99
+ ```
100
+
101
+ Finally, follow remaining library instructions to ensure you have a working frechet_audio_distance setup
102
+ (you may want to specify --model_ckpt flag pointing to the model's path).
103
+
104
+ 3. AudioCraft's FrechetAudioDistanceMetric requires 2 environment variables pointing to the python executable
105
+ and Tensorflow library path from the above installation steps:
106
+ export TF_PYTHON_EXE="<PATH_TO_THE_ENV_PYTHON_BINARY>"
107
+ export TF_LIBRARY_PATH="<PATH_TO_THE_ENV_CUDNN_LIBRARY>"
108
+
109
+ e.g. assuming we have installed everything in a dedicated conda env
110
+ with python 3.10 that is currently active:
111
+ export TF_PYTHON_EXE="$CONDA_PREFIX/bin/python"
112
+ export TF_LIBRARY_PATH="$CONDA_PREFIX/lib/python3.10/site-packages/nvidia/cudnn/lib"
113
+
114
+ Finally you may want to export the following variable:
115
+ export TF_FORCE_GPU_ALLOW_GROWTH=true
116
+ See: https://www.tensorflow.org/guide/gpu#limiting_gpu_memory_growth
117
+
118
+ You can save those environment variables in your training conda env, when currently active:
119
+ `$CONDA_PREFIX/etc/conda/activate.d/env_vars.sh`
120
+ e.g. assuming the env with TensorFlow and frechet_audio_distance install is named ac_eval,
121
+ and the training conda env is named audiocraft:
122
+ ```
123
+ # activate training env
124
+ conda activate audiocraft
125
+ # get path to all envs
126
+ CONDA_ENV_DIR=$(dirname $CONDA_PREFIX)
127
+ # export pointers to evaluation env for using TensorFlow in FrechetAudioDistanceMetric
128
+ touch $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
129
+ echo 'export TF_PYTHON_EXE="$CONDA_ENV_DIR/ac_eval/bin/python"' >> \
130
+ $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
131
+ echo 'export TF_LIBRARY_PATH="$CONDA_ENV_DIR/ac_eval/lib/python3.10/site-packages/nvidia/cudnn/lib"' >> \
132
+ $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
133
+ # optionally:
134
+ echo 'export TF_FORCE_GPU_ALLOW_GROWTH=true' >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
135
+ # you may need to reactivate the audiocraft env for this to take effect
136
+ ```
137
+
138
+ Args:
139
+ bin (Path or str): Path to installed frechet audio distance code.
140
+ model_path (Path or str): Path to Tensorflow checkpoint for the model
141
+ used to compute statistics over the embedding beams.
142
+ format (str): Audio format used to save files.
143
+ log_folder (Path or str, optional): Path where to write process logs.
144
+ """
145
+ def __init__(self, bin: tp.Union[Path, str], model_path: tp.Union[Path, str],
146
+ format: str = "wav", batch_size: tp.Optional[int] = None,
147
+ log_folder: tp.Optional[tp.Union[Path, str]] = None):
148
+ super().__init__()
149
+ self.model_sample_rate = VGGISH_SAMPLE_RATE
150
+ self.model_channels = VGGISH_CHANNELS
151
+ self.model_path = AudioCraftEnvironment.resolve_reference_path(model_path)
152
+ assert Path(self.model_path).exists(), f"Could not find provided model checkpoint path at: {self.model_path}"
153
+ self.format = format
154
+ self.batch_size = batch_size
155
+ self.bin = bin
156
+ self.tf_env = {"PYTHONPATH": str(self.bin)}
157
+ self.python_path = os.environ.get('TF_PYTHON_EXE') or 'python'
158
+ logger.info("Python exe for TF is %s", self.python_path)
159
+ if 'TF_LIBRARY_PATH' in os.environ:
160
+ self.tf_env['LD_LIBRARY_PATH'] = os.environ['TF_LIBRARY_PATH']
161
+ if 'TF_FORCE_GPU_ALLOW_GROWTH' in os.environ:
162
+ self.tf_env['TF_FORCE_GPU_ALLOW_GROWTH'] = os.environ['TF_FORCE_GPU_ALLOW_GROWTH']
163
+ logger.info("Env for TF is %r", self.tf_env)
164
+ self.reset(log_folder)
165
+ self.add_state("total_files", default=torch.tensor(0.), dist_reduce_fx="sum")
166
+
167
+ def reset(self, log_folder: tp.Optional[tp.Union[Path, str]] = None):
168
+ """Reset torchmetrics.Metrics state."""
169
+ log_folder = Path(log_folder or tempfile.mkdtemp())
170
+ self.tmp_dir = log_folder / 'fad'
171
+ self.tmp_dir.mkdir(exist_ok=True)
172
+ self.samples_tests_dir = self.tmp_dir / 'tests'
173
+ self.samples_tests_dir.mkdir(exist_ok=True)
174
+ self.samples_background_dir = self.tmp_dir / 'background'
175
+ self.samples_background_dir.mkdir(exist_ok=True)
176
+ self.manifest_tests = self.tmp_dir / 'files_tests.cvs'
177
+ self.manifest_background = self.tmp_dir / 'files_background.cvs'
178
+ self.stats_tests_dir = self.tmp_dir / 'stats_tests'
179
+ self.stats_background_dir = self.tmp_dir / 'stats_background'
180
+ self.counter = 0
181
+
182
+ def update(self, preds: torch.Tensor, targets: torch.Tensor,
183
+ sizes: torch.Tensor, sample_rates: torch.Tensor,
184
+ stems: tp.Optional[tp.List[str]] = None):
185
+ """Update torchmetrics.Metrics by saving the audio and updating the manifest file."""
186
+ assert preds.shape == targets.shape, f"preds={preds.shape} != targets={targets.shape}"
187
+ num_samples = preds.shape[0]
188
+ assert num_samples == sizes.size(0) and num_samples == sample_rates.size(0)
189
+ assert stems is None or num_samples == len(set(stems))
190
+ for i in range(num_samples):
191
+ self.total_files += 1 # type: ignore
192
+ self.counter += 1
193
+ wav_len = int(sizes[i].item())
194
+ sample_rate = int(sample_rates[i].item())
195
+ pred_wav = preds[i]
196
+ target_wav = targets[i]
197
+ pred_wav = pred_wav[..., :wav_len]
198
+ target_wav = target_wav[..., :wav_len]
199
+ stem_name = stems[i] if stems is not None else f'sample_{self.counter}_{flashy.distrib.rank()}'
200
+ # dump audio files
201
+ try:
202
+ pred_wav = convert_audio(
203
+ pred_wav.unsqueeze(0), from_rate=sample_rate,
204
+ to_rate=self.model_sample_rate, to_channels=1).squeeze(0)
205
+ audio_write(
206
+ self.samples_tests_dir / stem_name, pred_wav, sample_rate=self.model_sample_rate,
207
+ format=self.format, strategy="peak")
208
+ except Exception as e:
209
+ logger.error(f"Exception occured when saving tests files for FAD computation: {repr(e)} - {e}")
210
+ try:
211
+ # for the ground truth audio, we enforce the 'peak' strategy to avoid modifying
212
+ # the original audio when writing it
213
+ target_wav = convert_audio(
214
+ target_wav.unsqueeze(0), from_rate=sample_rate,
215
+ to_rate=self.model_sample_rate, to_channels=1).squeeze(0)
216
+ audio_write(
217
+ self.samples_background_dir / stem_name, target_wav, sample_rate=self.model_sample_rate,
218
+ format=self.format, strategy="peak")
219
+ except Exception as e:
220
+ logger.error(f"Exception occured when saving background files for FAD computation: {repr(e)} - {e}")
221
+
222
+ def _get_samples_name(self, is_background: bool):
223
+ return 'background' if is_background else 'tests'
224
+
225
+ def _create_embedding_beams(self, is_background: bool, gpu_index: tp.Optional[int] = None):
226
+ if is_background:
227
+ input_samples_dir = self.samples_background_dir
228
+ input_filename = self.manifest_background
229
+ stats_name = self.stats_background_dir
230
+ else:
231
+ input_samples_dir = self.samples_tests_dir
232
+ input_filename = self.manifest_tests
233
+ stats_name = self.stats_tests_dir
234
+ beams_name = self._get_samples_name(is_background)
235
+ log_file = self.tmp_dir / f'fad_logs_create_beams_{beams_name}.log'
236
+
237
+ logger.info(f"Scanning samples folder to fetch list of files: {input_samples_dir}")
238
+ with open(input_filename, "w") as fout:
239
+ for path in Path(input_samples_dir).glob(f"*.{self.format}"):
240
+ fout.write(f"{str(path)}\n")
241
+
242
+ cmd = [
243
+ self.python_path, "-m",
244
+ "frechet_audio_distance.create_embeddings_main",
245
+ "--model_ckpt", f"{self.model_path}",
246
+ "--input_files", f"{str(input_filename)}",
247
+ "--stats", f"{str(stats_name)}",
248
+ ]
249
+ if self.batch_size is not None:
250
+ cmd += ["--batch_size", str(self.batch_size)]
251
+ logger.info(f"Launching frechet_audio_distance embeddings main method: {' '.join(cmd)} on {beams_name}")
252
+ env = os.environ
253
+ if gpu_index is not None:
254
+ env["CUDA_VISIBLE_DEVICES"] = str(gpu_index)
255
+ process = subprocess.Popen(
256
+ cmd, stdout=open(log_file, "w"), env={**env, **self.tf_env}, stderr=subprocess.STDOUT)
257
+ return process, log_file
258
+
259
+ def _compute_fad_score(self, gpu_index: tp.Optional[int] = None):
260
+ cmd = [
261
+ self.python_path, "-m", "frechet_audio_distance.compute_fad",
262
+ "--test_stats", f"{str(self.stats_tests_dir)}",
263
+ "--background_stats", f"{str(self.stats_background_dir)}",
264
+ ]
265
+ logger.info(f"Launching frechet_audio_distance compute fad method: {' '.join(cmd)}")
266
+ env = os.environ
267
+ if gpu_index is not None:
268
+ env["CUDA_VISIBLE_DEVICES"] = str(gpu_index)
269
+ result = subprocess.run(cmd, env={**env, **self.tf_env}, capture_output=True)
270
+ if result.returncode:
271
+ logger.error(
272
+ "Error with FAD computation from stats: \n %s \n %s",
273
+ result.stdout.decode(), result.stderr.decode()
274
+ )
275
+ raise RuntimeError("Error while executing FAD computation from stats")
276
+ try:
277
+ # result is "FAD: (d+).(d+)" hence we remove the prefix with (d+) being one digit or more
278
+ fad_score = float(result.stdout[4:])
279
+ return fad_score
280
+ except Exception as e:
281
+ raise RuntimeError(f"Error parsing FAD score from command stdout: {e}")
282
+
283
+ def _log_process_result(self, returncode: int, log_file: tp.Union[Path, str], is_background: bool) -> None:
284
+ beams_name = self._get_samples_name(is_background)
285
+ if returncode:
286
+ with open(log_file, "r") as f:
287
+ error_log = f.read()
288
+ logger.error(error_log)
289
+ os._exit(1)
290
+ else:
291
+ logger.info(f"Successfully computed embedding beams on {beams_name} samples.")
292
+
293
+ def _parallel_create_embedding_beams(self, num_of_gpus: int):
294
+ assert num_of_gpus > 0
295
+ logger.info("Creating embeddings beams in a parallel manner on different GPUs")
296
+ tests_beams_process, tests_beams_log_file = self._create_embedding_beams(is_background=False, gpu_index=0)
297
+ bg_beams_process, bg_beams_log_file = self._create_embedding_beams(is_background=True, gpu_index=1)
298
+ tests_beams_code = tests_beams_process.wait()
299
+ bg_beams_code = bg_beams_process.wait()
300
+ self._log_process_result(tests_beams_code, tests_beams_log_file, is_background=False)
301
+ self._log_process_result(bg_beams_code, bg_beams_log_file, is_background=True)
302
+
303
+ def _sequential_create_embedding_beams(self):
304
+ logger.info("Creating embeddings beams in a sequential manner")
305
+ tests_beams_process, tests_beams_log_file = self._create_embedding_beams(is_background=False)
306
+ tests_beams_code = tests_beams_process.wait()
307
+ self._log_process_result(tests_beams_code, tests_beams_log_file, is_background=False)
308
+ bg_beams_process, bg_beams_log_file = self._create_embedding_beams(is_background=True)
309
+ bg_beams_code = bg_beams_process.wait()
310
+ self._log_process_result(bg_beams_code, bg_beams_log_file, is_background=True)
311
+
312
+ @flashy.distrib.rank_zero_only
313
+ def _local_compute_frechet_audio_distance(self):
314
+ """Compute Frechet Audio Distance score calling TensorFlow API."""
315
+ num_of_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
316
+ if num_of_gpus > 1:
317
+ self._parallel_create_embedding_beams(num_of_gpus)
318
+ else:
319
+ self._sequential_create_embedding_beams()
320
+ fad_score = self._compute_fad_score(gpu_index=0)
321
+ return fad_score
322
+
323
+ def compute(self) -> float:
324
+ """Compute metrics."""
325
+ assert self.total_files.item() > 0, "No files dumped for FAD computation!" # type: ignore
326
+ fad_score = self._local_compute_frechet_audio_distance()
327
+ logger.warning(f"FAD score = {fad_score}")
328
+ fad_score = flashy.distrib.broadcast_object(fad_score, src=0)
329
+ return fad_score
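A hedged usage sketch for the FAD metric above; both paths are placeholders for a local `frechet_audio_distance` checkout and a VGGish checkpoint set up as described in the docstring, with `TF_PYTHON_EXE`/`TF_LIBRARY_PATH` exported beforehand.

```python
# Hypothetical usage sketch for FrechetAudioDistanceMetric; paths are placeholders.
import torch

from audiocraft.metrics import FrechetAudioDistanceMetric

metric = FrechetAudioDistanceMetric(
    bin="/path/to/google-research",           # directory added to PYTHONPATH for the TF subprocesses
    model_path="/path/to/vggish_model.ckpt",  # VGGish checkpoint used to embed the audio
)
preds = torch.randn(2, 1, 16000)              # generated audio, shape (B, C, T)
targets = torch.randn(2, 1, 16000)            # reference audio, same shape
sizes = torch.tensor([16000, 16000])
sample_rates = torch.tensor([16000, 16000])
metric.update(preds, targets, sizes, sample_rates)  # writes 16kHz mono wavs for both sets to a temp folder
fad = metric.compute()                        # runs the embedding + FAD subprocesses and parses the score
```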
audiocraft/metrics/kld.py ADDED
@@ -0,0 +1,220 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import contextlib
8
+ from functools import partial
9
+ import logging
10
+ import os
11
+ import typing as tp
12
+
13
+ import torch
14
+ import torchmetrics
15
+
16
+ from ..data.audio_utils import convert_audio
17
+
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ class _patch_passt_stft:
23
+ """Decorator to patch torch.stft in PaSST."""
24
+ def __init__(self):
25
+ self.old_stft = torch.stft
26
+
27
+ def __enter__(self):
28
+ # return_complex is a mandatory parameter in latest torch versions
29
+ # torch is throwing RuntimeErrors when not set
30
+ torch.stft = partial(torch.stft, return_complex=False)
31
+
32
+ def __exit__(self, *exc):
33
+ torch.stft = self.old_stft
34
+
35
+
36
+ def kl_divergence(pred_probs: torch.Tensor, target_probs: torch.Tensor, epsilon: float = 1e-6) -> torch.Tensor:
37
+ """Computes the elementwise KL-Divergence loss between probability distributions
38
+ from generated samples and target samples.
39
+
40
+ Args:
41
+ pred_probs (torch.Tensor): Probabilities for each label obtained
42
+ from a classifier on generated audio. Expected shape is [B, num_classes].
43
+ target_probs (torch.Tensor): Probabilities for each label obtained
44
+ from a classifier on target audio. Expected shape is [B, num_classes].
45
+ epsilon (float): Epsilon value.
46
+ Returns:
47
+ kld (torch.Tensor): KLD loss between each generated sample and target pair.
48
+ """
49
+ kl_div = torch.nn.functional.kl_div((pred_probs + epsilon).log(), target_probs, reduction="none")
50
+ return kl_div.sum(-1)
51
+
52
+
53
+ class KLDivergenceMetric(torchmetrics.Metric):
54
+ """Base implementation for KL Divergence metric.
55
+
56
+ The KL divergence is measured between probability distributions
57
+ of class predictions returned by a pre-trained audio classification model.
58
+ When the KL-divergence is low, the generated audio is expected to
59
+ have similar acoustic characteristics as the reference audio,
60
+ according to the classifier.
61
+ """
62
+ def __init__(self):
63
+ super().__init__()
64
+ self.add_state("kld_pq_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
65
+ self.add_state("kld_qp_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
66
+ self.add_state("kld_all_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
67
+ self.add_state("weight", default=torch.tensor(0), dist_reduce_fx="sum")
68
+
69
+ def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor,
70
+ sample_rates: torch.Tensor) -> tp.Optional[torch.Tensor]:
71
+ """Get model output given provided input tensor.
72
+
73
+ Args:
74
+ x (torch.Tensor): Input audio tensor of shape [B, C, T].
75
+ sizes (torch.Tensor): Actual audio sample length, of shape [B].
76
+ sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
77
+ Returns:
78
+ probs (torch.Tensor): Probabilities over labels, of shape [B, num_classes].
79
+ """
80
+ raise NotImplementedError("implement method to extract label distributions from the model.")
81
+
82
+ def update(self, preds: torch.Tensor, targets: torch.Tensor,
83
+ sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
84
+ """Calculates running KL-Divergence loss between batches of audio
85
+ preds (generated) and targets (ground-truth).
86
+ Args:
87
+ preds (torch.Tensor): Audio samples to evaluate, of shape [B, C, T].
88
+ targets (torch.Tensor): Target samples to compare against, of shape [B, C, T].
89
+ sizes (torch.Tensor): Actual audio sample length, of shape [B].
90
+ sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
91
+ """
92
+ assert preds.shape == targets.shape
93
+ assert preds.size(0) > 0, "Cannot update the loss with empty tensors"
94
+ preds_probs = self._get_label_distribution(preds, sizes, sample_rates)
95
+ targets_probs = self._get_label_distribution(targets, sizes, sample_rates)
96
+ if preds_probs is not None and targets_probs is not None:
97
+ assert preds_probs.shape == targets_probs.shape
98
+ kld_scores = kl_divergence(preds_probs, targets_probs)
99
+ assert not torch.isnan(kld_scores).any(), "kld_scores contains NaN value(s)!"
100
+ self.kld_pq_sum += torch.sum(kld_scores)
101
+ kld_qp_scores = kl_divergence(targets_probs, preds_probs)
102
+ self.kld_qp_sum += torch.sum(kld_qp_scores)
103
+ self.weight += torch.tensor(kld_scores.size(0))
104
+
105
+ def compute(self) -> dict:
106
+ """Computes KL-Divergence across all evaluated pred/target pairs."""
107
+ weight: float = float(self.weight.item()) # type: ignore
108
+ assert weight > 0, "Unable to compute with total number of comparisons <= 0"
109
+ logger.info(f"Computing KL divergence on a total of {weight} samples")
110
+ kld_pq = self.kld_pq_sum.item() / weight # type: ignore
111
+ kld_qp = self.kld_qp_sum.item() / weight # type: ignore
112
+ kld_both = kld_pq + kld_qp
113
+ return {'kld': kld_pq, 'kld_pq': kld_pq, 'kld_qp': kld_qp, 'kld_both': kld_both}
114
+
115
+
116
+ class PasstKLDivergenceMetric(KLDivergenceMetric):
117
+ """KL-Divergence metric based on pre-trained PASST classifier on AudioSet.
118
+
119
+ From: PaSST: Efficient Training of Audio Transformers with Patchout
120
+ Paper: https://arxiv.org/abs/2110.05069
121
+ Implementation: https://github.com/kkoutini/PaSST
122
+
123
+ Follow instructions from the github repo:
124
+ ```
125
+ pip install 'git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt'
126
+ ```
127
+
128
+ Args:
129
+ pretrained_length (float, optional): Audio duration used for the pretrained model.
130
+ """
131
+ def __init__(self, pretrained_length: tp.Optional[float] = None):
132
+ super().__init__()
133
+ self._initialize_model(pretrained_length)
134
+
135
+ def _initialize_model(self, pretrained_length: tp.Optional[float] = None):
136
+ """Initialize underlying PaSST audio classifier."""
137
+ model, sr, max_frames, min_frames = self._load_base_model(pretrained_length)
138
+ self.min_input_frames = min_frames
139
+ self.max_input_frames = max_frames
140
+ self.model_sample_rate = sr
141
+ self.model = model
142
+ self.model.eval()
143
+ self.model.to(self.device)
144
+
145
+ def _load_base_model(self, pretrained_length: tp.Optional[float]):
146
+ """Load pretrained model from PaSST."""
147
+ try:
148
+ if pretrained_length == 30:
149
+ from hear21passt.base30sec import get_basic_model # type: ignore
150
+ max_duration = 30
151
+ elif pretrained_length == 20:
152
+ from hear21passt.base20sec import get_basic_model # type: ignore
153
+ max_duration = 20
154
+ else:
155
+ from hear21passt.base import get_basic_model # type: ignore
156
+ # Original PASST was trained on AudioSet with 10s-long audio samples
157
+ max_duration = 10
158
+ min_duration = 0.15
159
+ min_duration = 0.15
160
+ except ModuleNotFoundError:
161
+ raise ModuleNotFoundError(
162
+ "Please install hear21passt to compute KL divergence: ",
163
+ "pip install 'git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt'"
164
+ )
165
+ model_sample_rate = 32_000
166
+ max_input_frames = int(max_duration * model_sample_rate)
167
+ min_input_frames = int(min_duration * model_sample_rate)
168
+ with open(os.devnull, 'w') as f, contextlib.redirect_stdout(f):
169
+ model = get_basic_model(mode='logits')
170
+ return model, model_sample_rate, max_input_frames, min_input_frames
171
+
172
+ def _process_audio(self, wav: torch.Tensor, sample_rate: int, wav_len: int) -> tp.List[torch.Tensor]:
173
+ """Process audio to feed to the pretrained model."""
174
+ wav = wav.unsqueeze(0)
175
+ wav = wav[..., :wav_len]
176
+ wav = convert_audio(wav, from_rate=sample_rate, to_rate=self.model_sample_rate, to_channels=1)
177
+ wav = wav.squeeze(0)
178
+ # we don't pad but return a list of audio segments as this otherwise affects the KLD computation
179
+ segments = torch.split(wav, self.max_input_frames, dim=-1)
180
+ valid_segments = []
181
+ for s in segments:
182
+ # ignoring too small segments that are breaking the model inference
183
+ if s.size(-1) > self.min_input_frames:
184
+ valid_segments.append(s)
185
+ return [s[None] for s in valid_segments]
186
+
187
+ def _get_model_preds(self, wav: torch.Tensor) -> torch.Tensor:
188
+ """Run the pretrained model and get the predictions."""
189
+ assert wav.dim() == 3, f"Unexpected number of dims for preprocessed wav: {wav.shape}"
190
+ wav = wav.mean(dim=1)
191
+ # PaSST is printing a lot of garbage that we are not interested in
192
+ with open(os.devnull, "w") as f, contextlib.redirect_stdout(f):
193
+ with torch.no_grad(), _patch_passt_stft():
194
+ logits = self.model(wav.to(self.device))
195
+ probs = torch.softmax(logits, dim=-1)
196
+ return probs
197
+
198
+ def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor,
199
+ sample_rates: torch.Tensor) -> tp.Optional[torch.Tensor]:
200
+ """Get model output given provided input tensor.
201
+
202
+ Args:
203
+ x (torch.Tensor): Input audio tensor of shape [B, C, T].
204
+ sizes (torch.Tensor): Actual audio sample length, of shape [B].
205
+ sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
206
+ Returns:
207
+ probs (torch.Tensor, optional): Probabilities over labels, of shape [B, num_classes].
208
+ """
209
+ all_probs: tp.List[torch.Tensor] = []
210
+ for i, wav in enumerate(x):
211
+ sample_rate = int(sample_rates[i].item())
212
+ wav_len = int(sizes[i].item())
213
+ wav_segments = self._process_audio(wav, sample_rate, wav_len)
214
+ for segment in wav_segments:
215
+ probs = self._get_model_preds(segment).mean(dim=0)
216
+ all_probs.append(probs)
217
+ if len(all_probs) > 0:
218
+ return torch.stack(all_probs, dim=0)
219
+ else:
220
+ return None
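A hedged usage sketch for the PaSST-based KL divergence metric above, assuming `hear21passt` has been installed as described in the docstring; tensor shapes are illustrative.

```python
# Hypothetical usage sketch for PasstKLDivergenceMetric; tensors are random placeholders.
import torch

from audiocraft.metrics import PasstKLDivergenceMetric

metric = PasstKLDivergenceMetric()           # default: 10s PaSST classifier trained on AudioSet (32kHz)
preds = torch.randn(2, 1, 32000)             # generated audio, shape (B, C, T)
targets = torch.randn(2, 1, 32000)           # reference audio, same shape
sizes = torch.tensor([32000, 32000])         # valid lengths in samples
sample_rates = torch.tensor([32000, 32000])
metric.update(preds, targets, sizes, sample_rates)
print(metric.compute())                      # dict with 'kld', 'kld_pq', 'kld_qp', 'kld_both' averaged over samples
```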
audiocraft/metrics/rvm.py ADDED
@@ -0,0 +1,110 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import typing as tp
8
+ import torch
9
+ from torch import nn
10
+ import torchaudio
11
+
12
+
13
+ def db_to_scale(volume: tp.Union[float, torch.Tensor]):
14
+ return 10 ** (volume / 20)
15
+
16
+
17
+ def scale_to_db(scale: torch.Tensor, min_volume: float = -120):
18
+ min_scale = db_to_scale(min_volume)
19
+ return 20 * torch.log10(scale.clamp(min=min_scale))
20
+
21
+
22
+ class RelativeVolumeMel(nn.Module):
23
+ """Relative volume melspectrogram measure.
24
+
25
+ Computes a measure of distance over two mel spectrogram that is interpretable in terms
26
+ of decibels. Given `x_ref` and `x_est` two waveforms of shape `[*, T]`, it will
27
+ first renormalize both by the ground truth of `x_ref`.
28
+
29
+ ..Warning:: This class returns the volume of the distortion at the spectrogram level,
30
+ e.g. low negative values reflects lower distortion levels. For a SNR (like reported
31
+ in the MultiBandDiffusion paper), just take `-rvm`.
32
+
33
+ Then it computes the mel spectrogram `z_ref` and `z_est` and compute volume of the difference
34
+ relative to the volume of `z_ref` for each time-frequency bin. It further adds some limits, e.g.
35
+ clamping the values between -25 and 25 dB (controlled by `min_relative_volume` and `max_relative_volume`)
36
+ with the goal of avoiding the loss being dominated by parts where the reference is almost silent.
37
+ Indeed, volumes in dB can take unbounded values both towards -oo and +oo, which can make the final
38
+ average metric harder to interpret. Besides, anything below -30 dB of attenuation would sound extremely
39
+ good (for a neural network output, although sound engineers typically aim for much lower attenuations).
40
+ Similarly, anything above +30 dB would just be completely missing the target, and there is no point
41
+ in measuring by exactly how much it missed it. -25, 25 is a more conservative range, but also more
42
+ in line with what neural nets currently can achieve.
43
+
44
+ For instance, a Relative Volume Mel (RVM) score of -10 dB means that on average, the delta between
45
+ the target and reference mel-spec is 10 dB lower than the reference mel-spec value.
46
+
47
+ The metric can be aggregated over a given frequency band in order have different insights for
48
+ different region of the spectrum. `num_aggregated_bands` controls the number of bands.
49
+
50
+ ..Warning:: While this function is optimized for interpretability, nothing was done to ensure it
51
+ is numerically stable when computing its gradient. We thus advise against using it as a training loss.
52
+
53
+ Args:
54
+ sample_rate (int): Sample rate of the input audio.
55
+ n_mels (int): Number of mel bands to use.
56
+ n_fft (int): Number of frequency bins for the STFT.
57
+ hop_length (int): Hop length of the STFT and the mel-spectrogram.
58
+ min_relative_volume (float): The error `z_ref - z_est` volume is given relative to
59
+ the volume of `z_ref`. If error is smaller than -25 dB of `z_ref`, then it is clamped.
60
+ max_relative_volume (float): Same as `min_relative_volume` but clamping if the error is larger than that.
61
+ max_initial_gain (float): When rescaling the audio at the very beginning, we will limit the gain
62
+ to that amount, to avoid rescaling near silence. Given in dB.
63
+ min_activity_volume (float): When computing the reference level from `z_ref`, will clamp low volume
64
+ bins to that amount. This is effectively our "zero" level for the reference mel-spectrogram,
65
+ and anything below that will be considered equally.
66
+ num_aggregated_bands (int): Number of bands to keep when computing the average RVM value.
67
+ For instance, a value of 3 would give 3 scores, roughly for low, mid and high freqs.
68
+ """
69
+ def __init__(self, sample_rate: int = 24000, n_mels: int = 80, n_fft: int = 512,
70
+ hop_length: int = 128, min_relative_volume: float = -25,
71
+ max_relative_volume: float = 25, max_initial_gain: float = 25,
72
+ min_activity_volume: float = -25,
73
+ num_aggregated_bands: int = 4) -> None:
74
+ super().__init__()
75
+ self.melspec = torchaudio.transforms.MelSpectrogram(
76
+ n_mels=n_mels, n_fft=n_fft, hop_length=hop_length,
77
+ normalized=True, sample_rate=sample_rate, power=2)
78
+ self.min_relative_volume = min_relative_volume
79
+ self.max_relative_volume = max_relative_volume
80
+ self.max_initial_gain = max_initial_gain
81
+ self.min_activity_volume = min_activity_volume
82
+ self.num_aggregated_bands = num_aggregated_bands
83
+
84
+ def forward(self, estimate: torch.Tensor, ground_truth: torch.Tensor) -> tp.Dict[str, torch.Tensor]:
85
+ """Compute RVM metric between estimate and reference samples.
86
+
87
+ Args:
88
+ estimate (torch.Tensor): Estimate sample.
89
+ ground_truth (torch.Tensor): Reference sample.
90
+
91
+ Returns:
92
+ dict[str, torch.Tensor]: Metrics with keys `rvm` for the overall average, and `rvm_{k}`
93
+ for the RVM over the k-th band (k=0..num_aggregated_bands - 1).
94
+ """
95
+ min_scale = db_to_scale(-self.max_initial_gain)
96
+ std = ground_truth.pow(2).mean().sqrt().clamp(min=min_scale)
97
+ z_gt = self.melspec(ground_truth / std).sqrt()
98
+ z_est = self.melspec(estimate / std).sqrt()
99
+
100
+ delta = z_gt - z_est
101
+ ref_db = scale_to_db(z_gt, self.min_activity_volume)
102
+ delta_db = scale_to_db(delta.abs(), min_volume=-120)
103
+ relative_db = (delta_db - ref_db).clamp(self.min_relative_volume, self.max_relative_volume)
104
+ dims = list(range(relative_db.dim()))
105
+ dims.remove(dims[-2])
106
+ losses_per_band = relative_db.mean(dim=dims)
107
+ aggregated = [chunk.mean() for chunk in losses_per_band.chunk(self.num_aggregated_bands, dim=0)]
108
+ metrics = {f'rvm_{index}': value for index, value in enumerate(aggregated)}
109
+ metrics['rvm'] = losses_per_band.mean()
110
+ return metrics
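A minimal, hypothetical usage sketch for the RVM measure above; the distorted estimate is synthesized here purely for illustration.

```python
# Hypothetical usage sketch for RelativeVolumeMel; waveforms are synthetic placeholders.
import torch

from audiocraft.metrics import RelativeVolumeMel

rvm = RelativeVolumeMel(sample_rate=24000)
reference = torch.randn(1, 1, 24000)                        # "ground truth" waveform
estimate = reference + 0.01 * torch.randn_like(reference)   # mildly distorted copy standing in for a model output
scores = rvm(estimate, reference)
print(scores["rvm"])                                        # overall relative volume of the error, in dB
print({k: v.item() for k, v in scores.items() if k != "rvm"})  # per-band values rvm_0..rvm_3 (low to high freqs)
```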
audiocraft/metrics/visqol.py ADDED
@@ -0,0 +1,216 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import csv
8
+ import json
9
+ import logging
10
+ from pathlib import Path
11
+ import tempfile
12
+ import typing as tp
13
+ import subprocess
14
+ import shutil
15
+
16
+ import torch
17
+ import torchaudio
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ class ViSQOL:
23
+ """ViSQOL wrapper to run ViSQOL from Python using a pre-installed binary.
24
+
25
+ To learn more about ViSQOL and how to build ViSQOL binary using bazel, please refer to the
26
+ instructions available in the open source repository: https://github.com/google/visqol
27
+
28
+ ViSQOL is capable of running in two modes:
29
+
30
+ Audio Mode:
31
+ When running in audio mode, input signals must have a 48kHz sample rate. Input should be resampled to 48kHz.
32
+ Input signals can be multi-channel, but they will be down-mixed to mono for performing the comparison.
33
+ Audio mode uses support vector regression, with the maximum range at ~4.75.
34
+
35
+ Speech Mode:
36
+ When running in speech mode, ViSQOL uses a wideband model. It therefore expects input sample rates of 16kHz.
37
+ Input should be resampled to 16kHz.
38
+ As part of the speech mode processing, a root mean square implementation for voice activity detection
39
+ is performed on the reference signal to determine what parts of the signal have voice activity and
40
+ should therefore be included in the comparison. The signal is normalized before performing the voice
41
+ activity detection.
42
+ Input signals can be multi-channel, but they will be down-mixed to mono for performing the comparison.
43
+ Speech mode is scaled to have a maximum MOS of 5.0 to match previous version behavior.
44
+
45
+ For more details, check the guidelines: https://github.com/google/visqol#general-guidelines-for-input
46
+
47
+ Args:
48
+ visqol_bin (str): Path to the ViSQOL binary.
49
+ mode (str): ViSQOL computation mode, expecting "audio" or "speech".
50
+ model (str): Name of the model to use for similarity to quality model.
51
+ debug (bool): Whether to also get debug metrics from ViSQOL or not.
52
+ """
53
+ SAMPLE_RATES_MODES = {"audio": 48_000, "speech": 16_000}
54
+ ALLOWED_SAMPLE_RATES = frozenset(SAMPLE_RATES_MODES.values())
55
+
56
+ def __init__(self, bin: tp.Union[Path, str], mode: str = "audio",
57
+ model: str = "libsvm_nu_svr_model.txt", debug: bool = False):
58
+ assert bin is not None and Path(bin).exists(), f"Could not find ViSQOL binary in specified path: {bin}"
59
+ self.visqol_bin = str(bin)
60
+ self.visqol_mode = mode
61
+ self.target_sr = self._get_target_sr(self.visqol_mode)
62
+ self.model = model
63
+ self.debug = debug
64
+ assert Path(self.visqol_model).exists(), \
65
+ f"Could not find the specified model in ViSQOL install: {self.visqol_model}"
66
+
67
+ def _get_target_sr(self, mode: str) -> int:
68
+ # returns target sampling rate for the corresponding ViSQOL mode.
69
+ if mode not in ViSQOL.SAMPLE_RATES_MODES:
70
+ raise ValueError(
71
+ f"Unsupported mode! Allowed are: {', '.join(ViSQOL.SAMPLE_RATES_MODES.keys())}"
72
+ )
73
+ return ViSQOL.SAMPLE_RATES_MODES[mode]
74
+
75
+ def _prepare_files(
76
+ self, ref_sig: torch.Tensor, deg_sig: torch.Tensor, sr: int, target_sr: int, pad_with_silence: bool = False
77
+ ):
78
+ # prepare files for ViSQOL evaluation.
79
+ assert target_sr in ViSQOL.ALLOWED_SAMPLE_RATES
80
+ assert len(ref_sig) == len(deg_sig), (
81
+ "Expects same number of ref and degraded inputs",
82
+ f" but ref len {len(ref_sig)} != deg len {len(deg_sig)}"
83
+ )
84
+ # resample audio if needed
85
+ if sr != target_sr:
86
+ transform = torchaudio.transforms.Resample(sr, target_sr)
87
+ pad = int(0.5 * target_sr)
88
+ rs_ref = []
89
+ rs_deg = []
90
+ for i in range(len(ref_sig)):
91
+ rs_ref_i = transform(ref_sig[i])
92
+ rs_deg_i = transform(deg_sig[i])
93
+ if pad_with_silence:
94
+ rs_ref_i = torch.nn.functional.pad(rs_ref_i, (pad, pad), mode='constant', value=0)
95
+ rs_deg_i = torch.nn.functional.pad(rs_deg_i, (pad, pad), mode='constant', value=0)
96
+ rs_ref.append(rs_ref_i)
97
+ rs_deg.append(rs_deg_i)
98
+ ref_sig = torch.stack(rs_ref)
99
+ deg_sig = torch.stack(rs_deg)
100
+ # save audio chunks to tmp dir and create csv
101
+ tmp_dir = Path(tempfile.mkdtemp())
102
+ try:
103
+ tmp_input_csv_path = tmp_dir / "input.csv"
104
+ tmp_results_csv_path = tmp_dir / "results.csv"
105
+ tmp_debug_json_path = tmp_dir / "debug.json"
106
+ with open(tmp_input_csv_path, "w") as csv_file:
107
+ csv_writer = csv.writer(csv_file)
108
+ csv_writer.writerow(["reference", "degraded"])
109
+ for i in range(len(ref_sig)):
110
+ tmp_ref_filename = tmp_dir / f"ref_{i}.wav"
111
+ tmp_deg_filename = tmp_dir / f"deg_{i}.wav"
112
+ torchaudio.save(
113
+ tmp_ref_filename,
114
+ torch.clamp(ref_sig[i], min=-0.99, max=0.99),
115
+ sample_rate=target_sr,
116
+ bits_per_sample=16,
117
+ encoding="PCM_S"
118
+ )
119
+ torchaudio.save(
120
+ tmp_deg_filename,
121
+ torch.clamp(deg_sig[i], min=-0.99, max=0.99),
122
+ sample_rate=target_sr,
123
+ bits_per_sample=16,
124
+ encoding="PCM_S"
125
+ )
126
+ csv_writer.writerow([str(tmp_ref_filename), str(tmp_deg_filename)])
127
+ return tmp_dir, tmp_input_csv_path, tmp_results_csv_path, tmp_debug_json_path
128
+ except Exception as e:
129
+ logger.error("Exception occurred when preparing files for ViSQOL: %s", e)
130
+ return tmp_dir, None, None, None
131
+
132
+ def _flush_files(self, tmp_dir: tp.Union[Path, str]):
133
+ # flush tmp files used to compute ViSQOL.
134
+ shutil.rmtree(str(tmp_dir))
135
+
136
+ def _collect_moslqo_score(self, results_csv_path: tp.Union[Path, str]) -> float:
137
+ # collect results for each evaluated pair and return averaged moslqo score.
138
+ with open(results_csv_path, "r") as csv_file:
139
+ reader = csv.DictReader(csv_file)
140
+ moslqo_scores = [float(row["moslqo"]) for row in reader]
141
+ if len(moslqo_scores) > 0:
142
+ return sum(moslqo_scores) / len(moslqo_scores)
143
+ else:
144
+ return 0.0
145
+
146
+ def _collect_debug_data(self, debug_json_path: tp.Union[Path, str]) -> dict:
147
+ # collect debug data for the visqol inference.
148
+ with open(debug_json_path, "r") as f:
149
+ data = json.load(f)
150
+ return data
151
+
152
+ @property
153
+ def visqol_model(self):
154
+ return f'{self.visqol_bin}/model/{self.model}'
155
+
156
+ def _run_visqol(
157
+ self,
158
+ input_csv_path: tp.Union[Path, str],
159
+ results_csv_path: tp.Union[Path, str],
160
+ debug_csv_path: tp.Optional[tp.Union[Path, str]],
161
+ ):
162
+ input_csv_path = str(input_csv_path)
163
+ results_csv_path = str(results_csv_path)
164
+ debug_csv_path = str(debug_csv_path)
165
+ cmd = [
166
+ f'{self.visqol_bin}/bazel-bin/visqol',
167
+ '--batch_input_csv', f'{input_csv_path}',
168
+ '--results_csv', f'{results_csv_path}'
169
+ ]
170
+ if debug_csv_path is not None:
171
+ cmd += ['--output_debug', f'{debug_csv_path}']
172
+ if self.visqol_mode == "speech":
173
+ cmd += ['--use_speech_mode']
174
+ cmd += ['--similarity_to_quality_model', f'{self.visqol_model}']
175
+ result = subprocess.run(cmd, capture_output=True)
176
+ if result.returncode:
177
+ logger.error("Error with visqol: \n %s \n %s", result.stdout.decode(), result.stderr.decode())
178
+ raise RuntimeError("Error while executing visqol")
179
+ result.check_returncode()
180
+
181
+ def __call__(
182
+ self,
183
+ ref_sig: torch.Tensor,
184
+ deg_sig: torch.Tensor,
185
+ sr: int,
186
+ pad_with_silence: bool = False,
187
+ ):
188
+ """Calculate the ViSQOL metric for a pair of audio signals at a given sample rate.
189
+ Args:
190
+ ref_sig (torch.Tensor): Reference signals as [B, C, T].
191
+ deg_sig (torch.Tensor): Degraded signals as [B, C, T].
192
+ sr (int): Sample rate of the two audio signals.
193
+ pad_with_silence (bool): Whether to pad the file with silences as recommended
194
+ in visqol guidelines (see: https://github.com/google/visqol#general-guidelines-for-input).
195
+ Returns:
196
+ float: The ViSQOL score or mean score for the batch.
197
+ """
198
+ logger.debug(f"Calculating visqol with mode={self.visqol_mode} on {len(ref_sig)} samples")
199
+ tmp_dir, input_csv, results_csv, debug_json = self._prepare_files(
200
+ ref_sig, deg_sig, sr, self.target_sr, pad_with_silence
201
+ )
202
+ try:
203
+ if input_csv and results_csv:
204
+ self._run_visqol(
205
+ input_csv,
206
+ results_csv,
207
+ debug_json if self.debug else None,
208
+ )
209
+ mosqol = self._collect_moslqo_score(results_csv)
210
+ return mosqol
211
+ else:
212
+ raise RuntimeError("Something unexpected happened when running VISQOL!")
213
+ except Exception as e:
214
+ logger.error("Exception occurred when running ViSQOL: %s", e)
215
+ finally:
216
+ self._flush_files(tmp_dir)
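Finally, a hedged usage sketch for the ViSQOL wrapper above: `bin` must point to a local ViSQOL checkout built with bazel that still contains its `model/` folder; the path below is a placeholder.

```python
# Hypothetical usage sketch for the ViSQOL wrapper; the binary path is a placeholder.
import torch

from audiocraft.metrics import ViSQOL

visqol = ViSQOL(bin="/path/to/visqol", mode="audio")  # audio mode scores at 48kHz, speech mode at 16kHz
ref = torch.randn(2, 1, 32000)                        # reference signals, shape (B, C, T)
deg = torch.randn(2, 1, 32000)                        # degraded / generated signals, same shape
score = visqol(ref, deg, sr=32000, pad_with_silence=True)  # inputs are resampled to the mode's rate first
print(score)                                          # mean MOS-LQO over the batch
```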