unpairedelectron07 commited on
Commit
26b4608
·
verified ·
1 Parent(s): f586664

Upload 6 files

Browse files
audiocraft/metrics/chroma_cosinesim.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torchmetrics
9
+
10
+ from ..data.audio_utils import convert_audio
11
+ from ..modules.chroma import ChromaExtractor
12
+
13
+
14
+ class ChromaCosineSimilarityMetric(torchmetrics.Metric):
15
+ """Chroma cosine similarity metric.
16
+
17
+ This metric extracts a chromagram for a reference waveform and
18
+ a generated waveform and compares each frame using the cosine similarity
19
+ function. The output is the mean cosine similarity.
20
+
21
+ Args:
22
+ sample_rate (int): Sample rate used by the chroma extractor.
23
+ n_chroma (int): Number of chroma used by the chroma extractor.
24
+ radix2_exp (int): Exponent for the chroma extractor.
25
+ argmax (bool): Whether the chroma extractor uses argmax.
26
+ eps (float): Epsilon for cosine similarity computation.
27
+ """
28
+ def __init__(self, sample_rate: int, n_chroma: int, radix2_exp: int, argmax: bool, eps: float = 1e-8):
29
+ super().__init__()
30
+ self.chroma_sample_rate = sample_rate
31
+ self.n_chroma = n_chroma
32
+ self.eps = eps
33
+ self.chroma_extractor = ChromaExtractor(sample_rate=self.chroma_sample_rate, n_chroma=self.n_chroma,
34
+ radix2_exp=radix2_exp, argmax=argmax)
35
+ self.add_state("cosine_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
36
+ self.add_state("weight", default=torch.tensor(0.), dist_reduce_fx="sum")
37
+
38
+ def update(self, preds: torch.Tensor, targets: torch.Tensor,
39
+ sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
40
+ """Compute cosine similarity between chromagrams and accumulate scores over the dataset."""
41
+ if preds.size(0) == 0:
42
+ return
43
+
44
+ assert preds.shape == targets.shape, (
45
+ f"Preds and target shapes mismatch: preds={preds.shape}, targets={targets.shape}")
46
+ assert preds.size(0) == sizes.size(0), (
47
+ f"Number of items in preds ({preds.shape}) mismatch ",
48
+ f"with sizes ({sizes.shape})")
49
+ assert preds.size(0) == sample_rates.size(0), (
50
+ f"Number of items in preds ({preds.shape}) mismatch ",
51
+ f"with sample_rates ({sample_rates.shape})")
52
+ assert torch.all(sample_rates == sample_rates[0].item()), "All sample rates are not the same in the batch"
53
+
54
+ device = self.weight.device
55
+ preds, targets = preds.to(device), targets.to(device) # type: ignore
56
+ sample_rate = sample_rates[0].item()
57
+ preds = convert_audio(preds, from_rate=sample_rate, to_rate=self.chroma_sample_rate, to_channels=1)
58
+ targets = convert_audio(targets, from_rate=sample_rate, to_rate=self.chroma_sample_rate, to_channels=1)
59
+ gt_chroma = self.chroma_extractor(targets)
60
+ gen_chroma = self.chroma_extractor(preds)
61
+ chroma_lens = (sizes / self.chroma_extractor.winhop).ceil().int()
62
+ for i in range(len(gt_chroma)):
63
+ t = int(chroma_lens[i].item())
64
+ cosine_sim = torch.nn.functional.cosine_similarity(
65
+ gt_chroma[i, :t], gen_chroma[i, :t], dim=1, eps=self.eps)
66
+ self.cosine_sum += cosine_sim.sum(dim=0) # type: ignore
67
+ self.weight += torch.tensor(t) # type: ignore
68
+
69
+ def compute(self) -> float:
70
+ """Computes the average cosine similarty across all generated/target chromagrams pairs."""
71
+ assert self.weight.item() > 0, "Unable to compute with total number of comparisons <= 0" # type: ignore
72
+ return (self.cosine_sum / self.weight).item() # type: ignore
audiocraft/metrics/clap_consistency.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from pathlib import Path
8
+ import typing as tp
9
+
10
+ import torch
11
+ import torchmetrics
12
+ from transformers import RobertaTokenizer # type: ignore
13
+
14
+ from ..data.audio_utils import convert_audio
15
+ from ..environment import AudioCraftEnvironment
16
+ from ..utils.utils import load_clap_state_dict
17
+
18
+ try:
19
+ import laion_clap # type: ignore
20
+ except ImportError:
21
+ laion_clap = None
22
+
23
+
24
+ class TextConsistencyMetric(torchmetrics.Metric):
25
+ """Text consistency metric measuring consistency between audio and text pairs."""
26
+
27
+ def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
28
+ raise NotImplementedError("implement how to update the metric from the audio and text pairs.")
29
+
30
+ def compute(self):
31
+ raise NotImplementedError("implement how to compute the final metric score.")
32
+
33
+
34
+ class CLAPTextConsistencyMetric(TextConsistencyMetric):
35
+ """Text consistency metric relying on Contrastive Language-Audio Pretraining (CLAP).
36
+
37
+ This metric is similar to the MuLan Cycle Consistency from MusicLM (https://arxiv.org/pdf/2301.11325.pdf)
38
+ or the CLAP score used in Make-An-Audio (https://arxiv.org/pdf/2301.12661v1.pdf).
39
+
40
+ As a joint audio-text embedding model, a pretrained CLAP model can be used to quantify the
41
+ similarity between audio-text pairs. We compute the CLAP embeddings from the text descriptions as
42
+ well as the generated audio based on them, and define the MCC metric as the average cosine similarity
43
+ between these embeddings.
44
+
45
+ Model implementation & pre-trained checkpoints: https://github.com/LAION-AI/CLAP
46
+ """
47
+ def __init__(self, model_path: tp.Union[str, Path], model_arch: str = 'HTSAT-tiny', enable_fusion: bool = False):
48
+ super().__init__()
49
+ if laion_clap is None:
50
+ raise ImportError("Please install CLAP to compute text consistency: 'pip install laion_clap'")
51
+ self.add_state("cosine_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
52
+ self.add_state("weight", default=torch.tensor(0.), dist_reduce_fx="sum")
53
+ self._initialize_model(model_path, model_arch, enable_fusion)
54
+
55
+ def _initialize_model(self, model_path: tp.Union[str, Path], model_arch: str, enable_fusion: bool):
56
+ model_path = AudioCraftEnvironment.resolve_reference_path(model_path)
57
+ self.tokenize = RobertaTokenizer.from_pretrained('roberta-base')
58
+ self.model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch)
59
+ self.model_sample_rate = 48_000
60
+ load_clap_state_dict(self.model, model_path)
61
+ self.model.eval()
62
+
63
+ def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict:
64
+ # we use the default params from CLAP module here as well
65
+ return self.tokenize(texts, padding="max_length", truncation=True, max_length=77, return_tensors="pt")
66
+
67
+ def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
68
+ """Compute cosine similarity between audio and text pairs and accumulate scores over the dataset."""
69
+ assert audio.size(0) == len(text), "Number of audio and text samples should match"
70
+ assert torch.all(sample_rates == sample_rates[0].item()), "All items in batch should have the same sample rate"
71
+ sample_rate = int(sample_rates[0].item())
72
+ # convert audio batch to 48kHz monophonic audio with no channel dimension: [B, C, T] -> [B, T]
73
+ audio = convert_audio(audio, from_rate=sample_rate, to_rate=self.model_sample_rate, to_channels=1).mean(dim=1)
74
+ audio_embeddings = self.model.get_audio_embedding_from_data(audio, use_tensor=True)
75
+ text_embeddings = self.model.get_text_embedding(text, tokenizer=self._tokenizer, use_tensor=True)
76
+ # cosine similarity between the text and the audio embedding
77
+ cosine_sim = torch.nn.functional.cosine_similarity(audio_embeddings, text_embeddings, dim=1, eps=1e-8)
78
+ self.cosine_sum += cosine_sim.sum(dim=0)
79
+ self.weight += torch.tensor(cosine_sim.size(0))
80
+
81
+ def compute(self):
82
+ """Computes the average cosine similarty across all audio/text pairs."""
83
+ assert self.weight.item() > 0, "Unable to compute with total number of comparisons <= 0" # type: ignore
84
+ return (self.cosine_sum / self.weight).item() # type: ignore
audiocraft/metrics/fad.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+ from pathlib import Path
9
+ import os
10
+ import subprocess
11
+ import tempfile
12
+ import typing as tp
13
+
14
+ from audiocraft.data.audio import audio_write
15
+ from audiocraft.data.audio_utils import convert_audio
16
+ import flashy
17
+ import torch
18
+ import torchmetrics
19
+
20
+ from ..environment import AudioCraftEnvironment
21
+
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+ VGGISH_SAMPLE_RATE = 16_000
26
+ VGGISH_CHANNELS = 1
27
+
28
+
29
+ class FrechetAudioDistanceMetric(torchmetrics.Metric):
30
+ """Fréchet Audio Distance computation based on official TensorFlow implementation from Google Research.
31
+
32
+ From: D.C. Dowson & B.V. Landau The Fréchet distance between
33
+ multivariate normal distributions
34
+ https://doi.org/10.1016/0047-259X(82)90077-X
35
+ The Fréchet distance between two multivariate gaussians,
36
+ `X ~ N(mu_x, sigma_x)` and `Y ~ N(mu_y, sigma_y)`, is `d^2`.
37
+ d^2 = (mu_x - mu_y)^2 + Tr(sigma_x + sigma_y - 2 * sqrt(sigma_x*sigma_y))
38
+ = (mu_x - mu_y)^2 + Tr(sigma_x) + Tr(sigma_y)
39
+ - 2 * Tr(sqrt(sigma_x*sigma_y)))
40
+
41
+ To use this FAD computation metric, you need to have the proper Frechet Audio Distance tool setup
42
+ from: https://github.com/google-research/google-research/tree/master/frechet_audio_distance
43
+ We provide the below instructions as reference but we do not guarantee for further support
44
+ in frechet_audio_distance installation. This was tested with python 3.10, cuda 11.8, tensorflow 2.12.0.
45
+
46
+ We recommend installing the frechet_audio_distance library in a dedicated env (e.g. conda).
47
+
48
+ 1. Get the code and models following the repository instructions. We used the steps below:
49
+ git clone git@github.com:google-research/google-research.git
50
+ git clone git@github.com:tensorflow/models.git
51
+ mkdir google-research/tensorflow_models
52
+ touch google-research/tensorflow_models/__init__.py
53
+ cp -r models/research/audioset google-research/tensorflow_models/
54
+ touch google-research/tensorflow_models/audioset/__init__.py
55
+ echo "from .vggish import mel_features, vggish_params, vggish_slim" > \
56
+ google-research/tensorflow_models/audioset/__init__.py
57
+ # we can now remove the tensorflow models repository
58
+ # rm -r models
59
+ cd google-research
60
+ Follow the instructions to download the vggish checkpoint. AudioCraft base configuration
61
+ assumes it is placed in the AudioCraft reference dir.
62
+
63
+ Note that we operate the following changes for the code to work with TensorFlow 2.X and python 3:
64
+ - Update xrange for range in:
65
+ https://github.com/google-research/google-research/blob/master/frechet_audio_distance/audioset_model.py
66
+ - Update `tf_record = tf.python_io.tf_record_iterator(filename).next()` to
67
+ `tf_record = tf.python_io.tf_record_iterator(filename).__next__()` in
68
+ https://github.com/google-research/google-research/blob/master/frechet_audio_distance/fad_utils.py
69
+ - Update `import vggish_params as params` to `from . import vggish_params as params` in:
70
+ https://github.com/tensorflow/models/blob/master/research/audioset/vggish/vggish_slim.py
71
+ - Add flag to provide a given batch size for running the AudioSet model in:
72
+ https://github.com/google-research/google-research/blob/master/frechet_audio_distance/create_embeddings_main.py
73
+ ```
74
+ flags.DEFINE_integer('batch_size', 64,
75
+ 'Number of samples in the batch for AudioSet model.')
76
+ ```
77
+ Ensure you pass the flag to the create_embeddings_beam.create_pipeline function, adding:
78
+ `batch_size=FLAGS.batch_size` to the provided parameters.
79
+
80
+ 2. Follow instructions for the library installation and a valid TensorFlow installation
81
+ ```
82
+ # e.g. instructions from: https://www.tensorflow.org/install/pip
83
+ conda install -c conda-forge cudatoolkit=11.8.0
84
+ python3 -m pip install nvidia-cudnn-cu11==8.6.0.163 tensorflow==2.12.*
85
+ mkdir -p $CONDA_PREFIX/etc/conda/activate.d
86
+ echo 'CUDNN_PATH=$(dirname $(python -c "import nvidia.cudnn;print(nvidia.cudnn.__file__)"))' \
87
+ >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
88
+ echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CONDA_PREFIX/lib/:$CUDNN_PATH/lib' \
89
+ >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
90
+ source $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
91
+ # Verify install: on a machine with GPU device
92
+ python3 -c "import tensorflow as tf; print(tf.config.list_physical_devices('GPU'))"
93
+ ```
94
+
95
+ Now install frechet_audio_distance required dependencies:
96
+ ```
97
+ # We assume we already have TensorFlow installed from the above steps
98
+ pip install apache-beam numpy scipy tf_slim
99
+ ```
100
+
101
+ Finally, follow remaining library instructions to ensure you have a working frechet_audio_distance setup
102
+ (you may want to specify --model_ckpt flag pointing to the model's path).
103
+
104
+ 3. AudioCraft's FrechetAudioDistanceMetric requires 2 environment variables pointing to the python executable
105
+ and Tensorflow library path from the above installation steps:
106
+ export TF_PYTHON_EXE="<PATH_TO_THE_ENV_PYTHON_BINARY>"
107
+ export TF_LIBRARY_PATH="<PATH_TO_THE_ENV_CUDNN_LIBRARY>"
108
+
109
+ e.g. assuming we have installed everything in a dedicated conda env
110
+ with python 3.10 that is currently active:
111
+ export TF_PYTHON_EXE="$CONDA_PREFIX/bin/python"
112
+ export TF_LIBRARY_PATH="$CONDA_PREFIX/lib/python3.10/site-packages/nvidia/cudnn/lib"
113
+
114
+ Finally you may want to export the following variable:
115
+ export TF_FORCE_GPU_ALLOW_GROWTH=true
116
+ See: https://www.tensorflow.org/guide/gpu#limiting_gpu_memory_growth
117
+
118
+ You can save those environment variables in your training conda env, when currently active:
119
+ `$CONDA_PREFIX/etc/conda/activate.d/env_vars.sh`
120
+ e.g. assuming the env with TensorFlow and frechet_audio_distance install is named ac_eval,
121
+ and the training conda env is named audiocraft:
122
+ ```
123
+ # activate training env
124
+ conda activate audiocraft
125
+ # get path to all envs
126
+ CONDA_ENV_DIR=$(dirname $CONDA_PREFIX)
127
+ # export pointers to evaluation env for using TensorFlow in FrechetAudioDistanceMetric
128
+ touch $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
129
+ echo 'export TF_PYTHON_EXE="$CONDA_ENV_DIR/ac_eval/bin/python"' >> \
130
+ $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
131
+ echo 'export TF_LIBRARY_PATH="$CONDA_ENV_DIR/ac_eval/lib/python3.10/site-packages/nvidia/cudnn/lib"' >> \
132
+ $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
133
+ # optionally:
134
+ echo 'export TF_FORCE_GPU_ALLOW_GROWTH=true' >> $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
135
+ # you may need to reactivate the audiocraft env for this to take effect
136
+ ```
137
+
138
+ Args:
139
+ bin (Path or str): Path to installed frechet audio distance code.
140
+ model_path (Path or str): Path to Tensorflow checkpoint for the model
141
+ used to compute statistics over the embedding beams.
142
+ format (str): Audio format used to save files.
143
+ log_folder (Path or str, optional): Path where to write process logs.
144
+ """
145
+ def __init__(self, bin: tp.Union[Path, str], model_path: tp.Union[Path, str],
146
+ format: str = "wav", batch_size: tp.Optional[int] = None,
147
+ log_folder: tp.Optional[tp.Union[Path, str]] = None):
148
+ super().__init__()
149
+ self.model_sample_rate = VGGISH_SAMPLE_RATE
150
+ self.model_channels = VGGISH_CHANNELS
151
+ self.model_path = AudioCraftEnvironment.resolve_reference_path(model_path)
152
+ assert Path(self.model_path).exists(), f"Could not find provided model checkpoint path at: {self.model_path}"
153
+ self.format = format
154
+ self.batch_size = batch_size
155
+ self.bin = bin
156
+ self.tf_env = {"PYTHONPATH": str(self.bin)}
157
+ self.python_path = os.environ.get('TF_PYTHON_EXE') or 'python'
158
+ logger.info("Python exe for TF is %s", self.python_path)
159
+ if 'TF_LIBRARY_PATH' in os.environ:
160
+ self.tf_env['LD_LIBRARY_PATH'] = os.environ['TF_LIBRARY_PATH']
161
+ if 'TF_FORCE_GPU_ALLOW_GROWTH' in os.environ:
162
+ self.tf_env['TF_FORCE_GPU_ALLOW_GROWTH'] = os.environ['TF_FORCE_GPU_ALLOW_GROWTH']
163
+ logger.info("Env for TF is %r", self.tf_env)
164
+ self.reset(log_folder)
165
+ self.add_state("total_files", default=torch.tensor(0.), dist_reduce_fx="sum")
166
+
167
+ def reset(self, log_folder: tp.Optional[tp.Union[Path, str]] = None):
168
+ """Reset torchmetrics.Metrics state."""
169
+ log_folder = Path(log_folder or tempfile.mkdtemp())
170
+ self.tmp_dir = log_folder / 'fad'
171
+ self.tmp_dir.mkdir(exist_ok=True)
172
+ self.samples_tests_dir = self.tmp_dir / 'tests'
173
+ self.samples_tests_dir.mkdir(exist_ok=True)
174
+ self.samples_background_dir = self.tmp_dir / 'background'
175
+ self.samples_background_dir.mkdir(exist_ok=True)
176
+ self.manifest_tests = self.tmp_dir / 'files_tests.cvs'
177
+ self.manifest_background = self.tmp_dir / 'files_background.cvs'
178
+ self.stats_tests_dir = self.tmp_dir / 'stats_tests'
179
+ self.stats_background_dir = self.tmp_dir / 'stats_background'
180
+ self.counter = 0
181
+
182
+ def update(self, preds: torch.Tensor, targets: torch.Tensor,
183
+ sizes: torch.Tensor, sample_rates: torch.Tensor,
184
+ stems: tp.Optional[tp.List[str]] = None):
185
+ """Update torchmetrics.Metrics by saving the audio and updating the manifest file."""
186
+ assert preds.shape == targets.shape, f"preds={preds.shape} != targets={targets.shape}"
187
+ num_samples = preds.shape[0]
188
+ assert num_samples == sizes.size(0) and num_samples == sample_rates.size(0)
189
+ assert stems is None or num_samples == len(set(stems))
190
+ for i in range(num_samples):
191
+ self.total_files += 1 # type: ignore
192
+ self.counter += 1
193
+ wav_len = int(sizes[i].item())
194
+ sample_rate = int(sample_rates[i].item())
195
+ pred_wav = preds[i]
196
+ target_wav = targets[i]
197
+ pred_wav = pred_wav[..., :wav_len]
198
+ target_wav = target_wav[..., :wav_len]
199
+ stem_name = stems[i] if stems is not None else f'sample_{self.counter}_{flashy.distrib.rank()}'
200
+ # dump audio files
201
+ try:
202
+ pred_wav = convert_audio(
203
+ pred_wav.unsqueeze(0), from_rate=sample_rate,
204
+ to_rate=self.model_sample_rate, to_channels=1).squeeze(0)
205
+ audio_write(
206
+ self.samples_tests_dir / stem_name, pred_wav, sample_rate=self.model_sample_rate,
207
+ format=self.format, strategy="peak")
208
+ except Exception as e:
209
+ logger.error(f"Exception occured when saving tests files for FAD computation: {repr(e)} - {e}")
210
+ try:
211
+ # for the ground truth audio, we enforce the 'peak' strategy to avoid modifying
212
+ # the original audio when writing it
213
+ target_wav = convert_audio(
214
+ target_wav.unsqueeze(0), from_rate=sample_rate,
215
+ to_rate=self.model_sample_rate, to_channels=1).squeeze(0)
216
+ audio_write(
217
+ self.samples_background_dir / stem_name, target_wav, sample_rate=self.model_sample_rate,
218
+ format=self.format, strategy="peak")
219
+ except Exception as e:
220
+ logger.error(f"Exception occured when saving background files for FAD computation: {repr(e)} - {e}")
221
+
222
+ def _get_samples_name(self, is_background: bool):
223
+ return 'background' if is_background else 'tests'
224
+
225
+ def _create_embedding_beams(self, is_background: bool, gpu_index: tp.Optional[int] = None):
226
+ if is_background:
227
+ input_samples_dir = self.samples_background_dir
228
+ input_filename = self.manifest_background
229
+ stats_name = self.stats_background_dir
230
+ else:
231
+ input_samples_dir = self.samples_tests_dir
232
+ input_filename = self.manifest_tests
233
+ stats_name = self.stats_tests_dir
234
+ beams_name = self._get_samples_name(is_background)
235
+ log_file = self.tmp_dir / f'fad_logs_create_beams_{beams_name}.log'
236
+
237
+ logger.info(f"Scanning samples folder to fetch list of files: {input_samples_dir}")
238
+ with open(input_filename, "w") as fout:
239
+ for path in Path(input_samples_dir).glob(f"*.{self.format}"):
240
+ fout.write(f"{str(path)}\n")
241
+
242
+ cmd = [
243
+ self.python_path, "-m",
244
+ "frechet_audio_distance.create_embeddings_main",
245
+ "--model_ckpt", f"{self.model_path}",
246
+ "--input_files", f"{str(input_filename)}",
247
+ "--stats", f"{str(stats_name)}",
248
+ ]
249
+ if self.batch_size is not None:
250
+ cmd += ["--batch_size", str(self.batch_size)]
251
+ logger.info(f"Launching frechet_audio_distance embeddings main method: {' '.join(cmd)} on {beams_name}")
252
+ env = os.environ
253
+ if gpu_index is not None:
254
+ env["CUDA_VISIBLE_DEVICES"] = str(gpu_index)
255
+ process = subprocess.Popen(
256
+ cmd, stdout=open(log_file, "w"), env={**env, **self.tf_env}, stderr=subprocess.STDOUT)
257
+ return process, log_file
258
+
259
+ def _compute_fad_score(self, gpu_index: tp.Optional[int] = None):
260
+ cmd = [
261
+ self.python_path, "-m", "frechet_audio_distance.compute_fad",
262
+ "--test_stats", f"{str(self.stats_tests_dir)}",
263
+ "--background_stats", f"{str(self.stats_background_dir)}",
264
+ ]
265
+ logger.info(f"Launching frechet_audio_distance compute fad method: {' '.join(cmd)}")
266
+ env = os.environ
267
+ if gpu_index is not None:
268
+ env["CUDA_VISIBLE_DEVICES"] = str(gpu_index)
269
+ result = subprocess.run(cmd, env={**env, **self.tf_env}, capture_output=True)
270
+ if result.returncode:
271
+ logger.error(
272
+ "Error with FAD computation from stats: \n %s \n %s",
273
+ result.stdout.decode(), result.stderr.decode()
274
+ )
275
+ raise RuntimeError("Error while executing FAD computation from stats")
276
+ try:
277
+ # result is "FAD: (d+).(d+)" hence we remove the prefix with (d+) being one digit or more
278
+ fad_score = float(result.stdout[4:])
279
+ return fad_score
280
+ except Exception as e:
281
+ raise RuntimeError(f"Error parsing FAD score from command stdout: {e}")
282
+
283
+ def _log_process_result(self, returncode: int, log_file: tp.Union[Path, str], is_background: bool) -> None:
284
+ beams_name = self._get_samples_name(is_background)
285
+ if returncode:
286
+ with open(log_file, "r") as f:
287
+ error_log = f.read()
288
+ logger.error(error_log)
289
+ os._exit(1)
290
+ else:
291
+ logger.info(f"Successfully computed embedding beams on {beams_name} samples.")
292
+
293
+ def _parallel_create_embedding_beams(self, num_of_gpus: int):
294
+ assert num_of_gpus > 0
295
+ logger.info("Creating embeddings beams in a parallel manner on different GPUs")
296
+ tests_beams_process, tests_beams_log_file = self._create_embedding_beams(is_background=False, gpu_index=0)
297
+ bg_beams_process, bg_beams_log_file = self._create_embedding_beams(is_background=True, gpu_index=1)
298
+ tests_beams_code = tests_beams_process.wait()
299
+ bg_beams_code = bg_beams_process.wait()
300
+ self._log_process_result(tests_beams_code, tests_beams_log_file, is_background=False)
301
+ self._log_process_result(bg_beams_code, bg_beams_log_file, is_background=True)
302
+
303
+ def _sequential_create_embedding_beams(self):
304
+ logger.info("Creating embeddings beams in a sequential manner")
305
+ tests_beams_process, tests_beams_log_file = self._create_embedding_beams(is_background=False)
306
+ tests_beams_code = tests_beams_process.wait()
307
+ self._log_process_result(tests_beams_code, tests_beams_log_file, is_background=False)
308
+ bg_beams_process, bg_beams_log_file = self._create_embedding_beams(is_background=True)
309
+ bg_beams_code = bg_beams_process.wait()
310
+ self._log_process_result(bg_beams_code, bg_beams_log_file, is_background=True)
311
+
312
+ @flashy.distrib.rank_zero_only
313
+ def _local_compute_frechet_audio_distance(self):
314
+ """Compute Frechet Audio Distance score calling TensorFlow API."""
315
+ num_of_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
316
+ if num_of_gpus > 1:
317
+ self._parallel_create_embedding_beams(num_of_gpus)
318
+ else:
319
+ self._sequential_create_embedding_beams()
320
+ fad_score = self._compute_fad_score(gpu_index=0)
321
+ return fad_score
322
+
323
+ def compute(self) -> float:
324
+ """Compute metrics."""
325
+ assert self.total_files.item() > 0, "No files dumped for FAD computation!" # type: ignore
326
+ fad_score = self._local_compute_frechet_audio_distance()
327
+ logger.warning(f"FAD score = {fad_score}")
328
+ fad_score = flashy.distrib.broadcast_object(fad_score, src=0)
329
+ return fad_score
audiocraft/metrics/kld.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import contextlib
8
+ from functools import partial
9
+ import logging
10
+ import os
11
+ import typing as tp
12
+
13
+ import torch
14
+ import torchmetrics
15
+
16
+ from ..data.audio_utils import convert_audio
17
+
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ class _patch_passt_stft:
23
+ """Decorator to patch torch.stft in PaSST."""
24
+ def __init__(self):
25
+ self.old_stft = torch.stft
26
+
27
+ def __enter__(self):
28
+ # return_complex is a mandatory parameter in latest torch versions
29
+ # torch is throwing RuntimeErrors when not set
30
+ torch.stft = partial(torch.stft, return_complex=False)
31
+
32
+ def __exit__(self, *exc):
33
+ torch.stft = self.old_stft
34
+
35
+
36
+ def kl_divergence(pred_probs: torch.Tensor, target_probs: torch.Tensor, epsilon: float = 1e-6) -> torch.Tensor:
37
+ """Computes the elementwise KL-Divergence loss between probability distributions
38
+ from generated samples and target samples.
39
+
40
+ Args:
41
+ pred_probs (torch.Tensor): Probabilities for each label obtained
42
+ from a classifier on generated audio. Expected shape is [B, num_classes].
43
+ target_probs (torch.Tensor): Probabilities for each label obtained
44
+ from a classifier on target audio. Expected shape is [B, num_classes].
45
+ epsilon (float): Epsilon value.
46
+ Returns:
47
+ kld (torch.Tensor): KLD loss between each generated sample and target pair.
48
+ """
49
+ kl_div = torch.nn.functional.kl_div((pred_probs + epsilon).log(), target_probs, reduction="none")
50
+ return kl_div.sum(-1)
51
+
52
+
53
+ class KLDivergenceMetric(torchmetrics.Metric):
54
+ """Base implementation for KL Divergence metric.
55
+
56
+ The KL divergence is measured between probability distributions
57
+ of class predictions returned by a pre-trained audio classification model.
58
+ When the KL-divergence is low, the generated audio is expected to
59
+ have similar acoustic characteristics as the reference audio,
60
+ according to the classifier.
61
+ """
62
+ def __init__(self):
63
+ super().__init__()
64
+ self.add_state("kld_pq_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
65
+ self.add_state("kld_qp_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
66
+ self.add_state("kld_all_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
67
+ self.add_state("weight", default=torch.tensor(0), dist_reduce_fx="sum")
68
+
69
+ def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor,
70
+ sample_rates: torch.Tensor) -> tp.Optional[torch.Tensor]:
71
+ """Get model output given provided input tensor.
72
+
73
+ Args:
74
+ x (torch.Tensor): Input audio tensor of shape [B, C, T].
75
+ sizes (torch.Tensor): Actual audio sample length, of shape [B].
76
+ sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
77
+ Returns:
78
+ probs (torch.Tensor): Probabilities over labels, of shape [B, num_classes].
79
+ """
80
+ raise NotImplementedError("implement method to extract label distributions from the model.")
81
+
82
+ def update(self, preds: torch.Tensor, targets: torch.Tensor,
83
+ sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
84
+ """Calculates running KL-Divergence loss between batches of audio
85
+ preds (generated) and target (ground-truth)
86
+ Args:
87
+ preds (torch.Tensor): Audio samples to evaluate, of shape [B, C, T].
88
+ targets (torch.Tensor): Target samples to compare against, of shape [B, C, T].
89
+ sizes (torch.Tensor): Actual audio sample length, of shape [B].
90
+ sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
91
+ """
92
+ assert preds.shape == targets.shape
93
+ assert preds.size(0) > 0, "Cannot update the loss with empty tensors"
94
+ preds_probs = self._get_label_distribution(preds, sizes, sample_rates)
95
+ targets_probs = self._get_label_distribution(targets, sizes, sample_rates)
96
+ if preds_probs is not None and targets_probs is not None:
97
+ assert preds_probs.shape == targets_probs.shape
98
+ kld_scores = kl_divergence(preds_probs, targets_probs)
99
+ assert not torch.isnan(kld_scores).any(), "kld_scores contains NaN value(s)!"
100
+ self.kld_pq_sum += torch.sum(kld_scores)
101
+ kld_qp_scores = kl_divergence(targets_probs, preds_probs)
102
+ self.kld_qp_sum += torch.sum(kld_qp_scores)
103
+ self.weight += torch.tensor(kld_scores.size(0))
104
+
105
+ def compute(self) -> dict:
106
+ """Computes KL-Divergence across all evaluated pred/target pairs."""
107
+ weight: float = float(self.weight.item()) # type: ignore
108
+ assert weight > 0, "Unable to compute with total number of comparisons <= 0"
109
+ logger.info(f"Computing KL divergence on a total of {weight} samples")
110
+ kld_pq = self.kld_pq_sum.item() / weight # type: ignore
111
+ kld_qp = self.kld_qp_sum.item() / weight # type: ignore
112
+ kld_both = kld_pq + kld_qp
113
+ return {'kld': kld_pq, 'kld_pq': kld_pq, 'kld_qp': kld_qp, 'kld_both': kld_both}
114
+
115
+
116
+ class PasstKLDivergenceMetric(KLDivergenceMetric):
117
+ """KL-Divergence metric based on pre-trained PASST classifier on AudioSet.
118
+
119
+ From: PaSST: Efficient Training of Audio Transformers with Patchout
120
+ Paper: https://arxiv.org/abs/2110.05069
121
+ Implementation: https://github.com/kkoutini/PaSST
122
+
123
+ Follow instructions from the github repo:
124
+ ```
125
+ pip install 'git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt'
126
+ ```
127
+
128
+ Args:
129
+ pretrained_length (float, optional): Audio duration used for the pretrained model.
130
+ """
131
+ def __init__(self, pretrained_length: tp.Optional[float] = None):
132
+ super().__init__()
133
+ self._initialize_model(pretrained_length)
134
+
135
+ def _initialize_model(self, pretrained_length: tp.Optional[float] = None):
136
+ """Initialize underlying PaSST audio classifier."""
137
+ model, sr, max_frames, min_frames = self._load_base_model(pretrained_length)
138
+ self.min_input_frames = min_frames
139
+ self.max_input_frames = max_frames
140
+ self.model_sample_rate = sr
141
+ self.model = model
142
+ self.model.eval()
143
+ self.model.to(self.device)
144
+
145
+ def _load_base_model(self, pretrained_length: tp.Optional[float]):
146
+ """Load pretrained model from PaSST."""
147
+ try:
148
+ if pretrained_length == 30:
149
+ from hear21passt.base30sec import get_basic_model # type: ignore
150
+ max_duration = 30
151
+ elif pretrained_length == 20:
152
+ from hear21passt.base20sec import get_basic_model # type: ignore
153
+ max_duration = 20
154
+ else:
155
+ from hear21passt.base import get_basic_model # type: ignore
156
+ # Original PASST was trained on AudioSet with 10s-long audio samples
157
+ max_duration = 10
158
+ min_duration = 0.15
159
+ min_duration = 0.15
160
+ except ModuleNotFoundError:
161
+ raise ModuleNotFoundError(
162
+ "Please install hear21passt to compute KL divergence: ",
163
+ "pip install 'git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt'"
164
+ )
165
+ model_sample_rate = 32_000
166
+ max_input_frames = int(max_duration * model_sample_rate)
167
+ min_input_frames = int(min_duration * model_sample_rate)
168
+ with open(os.devnull, 'w') as f, contextlib.redirect_stdout(f):
169
+ model = get_basic_model(mode='logits')
170
+ return model, model_sample_rate, max_input_frames, min_input_frames
171
+
172
+ def _process_audio(self, wav: torch.Tensor, sample_rate: int, wav_len: int) -> tp.List[torch.Tensor]:
173
+ """Process audio to feed to the pretrained model."""
174
+ wav = wav.unsqueeze(0)
175
+ wav = wav[..., :wav_len]
176
+ wav = convert_audio(wav, from_rate=sample_rate, to_rate=self.model_sample_rate, to_channels=1)
177
+ wav = wav.squeeze(0)
178
+ # we don't pad but return a list of audio segments as this otherwise affects the KLD computation
179
+ segments = torch.split(wav, self.max_input_frames, dim=-1)
180
+ valid_segments = []
181
+ for s in segments:
182
+ # ignoring too small segments that are breaking the model inference
183
+ if s.size(-1) > self.min_input_frames:
184
+ valid_segments.append(s)
185
+ return [s[None] for s in valid_segments]
186
+
187
+ def _get_model_preds(self, wav: torch.Tensor) -> torch.Tensor:
188
+ """Run the pretrained model and get the predictions."""
189
+ assert wav.dim() == 3, f"Unexpected number of dims for preprocessed wav: {wav.shape}"
190
+ wav = wav.mean(dim=1)
191
+ # PaSST is printing a lot of garbage that we are not interested in
192
+ with open(os.devnull, "w") as f, contextlib.redirect_stdout(f):
193
+ with torch.no_grad(), _patch_passt_stft():
194
+ logits = self.model(wav.to(self.device))
195
+ probs = torch.softmax(logits, dim=-1)
196
+ return probs
197
+
198
+ def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor,
199
+ sample_rates: torch.Tensor) -> tp.Optional[torch.Tensor]:
200
+ """Get model output given provided input tensor.
201
+
202
+ Args:
203
+ x (torch.Tensor): Input audio tensor of shape [B, C, T].
204
+ sizes (torch.Tensor): Actual audio sample length, of shape [B].
205
+ sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
206
+ Returns:
207
+ probs (torch.Tensor, optional): Probabilities over labels, of shape [B, num_classes].
208
+ """
209
+ all_probs: tp.List[torch.Tensor] = []
210
+ for i, wav in enumerate(x):
211
+ sample_rate = int(sample_rates[i].item())
212
+ wav_len = int(sizes[i].item())
213
+ wav_segments = self._process_audio(wav, sample_rate, wav_len)
214
+ for segment in wav_segments:
215
+ probs = self._get_model_preds(segment).mean(dim=0)
216
+ all_probs.append(probs)
217
+ if len(all_probs) > 0:
218
+ return torch.stack(all_probs, dim=0)
219
+ else:
220
+ return None
audiocraft/metrics/rvm.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import typing as tp
8
+ import torch
9
+ from torch import nn
10
+ import torchaudio
11
+
12
+
13
+ def db_to_scale(volume: tp.Union[float, torch.Tensor]):
14
+ return 10 ** (volume / 20)
15
+
16
+
17
+ def scale_to_db(scale: torch.Tensor, min_volume: float = -120):
18
+ min_scale = db_to_scale(min_volume)
19
+ return 20 * torch.log10(scale.clamp(min=min_scale))
20
+
21
+
22
+ class RelativeVolumeMel(nn.Module):
23
+ """Relative volume melspectrogram measure.
24
+
25
+ Computes a measure of distance over two mel spectrogram that is interpretable in terms
26
+ of decibels. Given `x_ref` and `x_est` two waveforms of shape `[*, T]`, it will
27
+ first renormalize both by the ground truth of `x_ref`.
28
+
29
+ ..Warning:: This class returns the volume of the distortion at the spectrogram level,
30
+ e.g. low negative values reflects lower distortion levels. For a SNR (like reported
31
+ in the MultiBandDiffusion paper), just take `-rvm`.
32
+
33
+ Then it computes the mel spectrogram `z_ref` and `z_est` and compute volume of the difference
34
+ relative to the volume of `z_ref` for each time-frequency bin. It further adds some limits, e.g.
35
+ clamping the values between -25 and 25 dB (controlled by `min_relative_volume` and `max_relative_volume`)
36
+ with the goal of avoiding the loss being dominated by parts where the reference is almost silent.
37
+ Indeed, volumes in dB can take unbounded values both towards -oo and +oo, which can make the final
38
+ average metric harder to interpret. Besides, anything below -30 dB of attenuation would sound extremely
39
+ good (for a neural network output, although sound engineers typically aim for much lower attenuations).
40
+ Similarly, anything above +30 dB would just be completely missing the target, and there is no point
41
+ in measuring by exactly how much it missed it. -25, 25 is a more conservative range, but also more
42
+ in line with what neural nets currently can achieve.
43
+
44
+ For instance, a Relative Volume Mel (RVM) score of -10 dB means that on average, the delta between
45
+ the target and reference mel-spec is 10 dB lower than the reference mel-spec value.
46
+
47
+ The metric can be aggregated over a given frequency band in order have different insights for
48
+ different region of the spectrum. `num_aggregated_bands` controls the number of bands.
49
+
50
+ ..Warning:: While this function is optimized for interpretability, nothing was done to ensure it
51
+ is numerically stable when computing its gradient. We thus advise against using it as a training loss.
52
+
53
+ Args:
54
+ sample_rate (int): Sample rate of the input audio.
55
+ n_mels (int): Number of mel bands to use.
56
+ n_fft (int): Number of frequency bins for the STFT.
57
+ hop_length (int): Hop length of the STFT and the mel-spectrogram.
58
+ min_relative_volume (float): The error `z_ref - z_est` volume is given relative to
59
+ the volume of `z_ref`. If error is smaller than -25 dB of `z_ref`, then it is clamped.
60
+ max_relative_volume (float): Same as `min_relative_volume` but clamping if the error is larger than that.
61
+ max_initial_gain (float): When rescaling the audio at the very beginning, we will limit the gain
62
+ to that amount, to avoid rescaling near silence. Given in dB.
63
+ min_activity_volume (float): When computing the reference level from `z_ref`, will clamp low volume
64
+ bins to that amount. This is effectively our "zero" level for the reference mel-spectrogram,
65
+ and anything below that will be considered equally.
66
+ num_aggregated_bands (int): Number of bands to keep when computing the average RVM value.
67
+ For instance, a value of 3 would give 3 scores, roughly for low, mid and high freqs.
68
+ """
69
+ def __init__(self, sample_rate: int = 24000, n_mels: int = 80, n_fft: int = 512,
70
+ hop_length: int = 128, min_relative_volume: float = -25,
71
+ max_relative_volume: float = 25, max_initial_gain: float = 25,
72
+ min_activity_volume: float = -25,
73
+ num_aggregated_bands: int = 4) -> None:
74
+ super().__init__()
75
+ self.melspec = torchaudio.transforms.MelSpectrogram(
76
+ n_mels=n_mels, n_fft=n_fft, hop_length=hop_length,
77
+ normalized=True, sample_rate=sample_rate, power=2)
78
+ self.min_relative_volume = min_relative_volume
79
+ self.max_relative_volume = max_relative_volume
80
+ self.max_initial_gain = max_initial_gain
81
+ self.min_activity_volume = min_activity_volume
82
+ self.num_aggregated_bands = num_aggregated_bands
83
+
84
+ def forward(self, estimate: torch.Tensor, ground_truth: torch.Tensor) -> tp.Dict[str, torch.Tensor]:
85
+ """Compute RVM metric between estimate and reference samples.
86
+
87
+ Args:
88
+ estimate (torch.Tensor): Estimate sample.
89
+ ground_truth (torch.Tensor): Reference sample.
90
+
91
+ Returns:
92
+ dict[str, torch.Tensor]: Metrics with keys `rvm` for the overall average, and `rvm_{k}`
93
+ for the RVM over the k-th band (k=0..num_aggregated_bands - 1).
94
+ """
95
+ min_scale = db_to_scale(-self.max_initial_gain)
96
+ std = ground_truth.pow(2).mean().sqrt().clamp(min=min_scale)
97
+ z_gt = self.melspec(ground_truth / std).sqrt()
98
+ z_est = self.melspec(estimate / std).sqrt()
99
+
100
+ delta = z_gt - z_est
101
+ ref_db = scale_to_db(z_gt, self.min_activity_volume)
102
+ delta_db = scale_to_db(delta.abs(), min_volume=-120)
103
+ relative_db = (delta_db - ref_db).clamp(self.min_relative_volume, self.max_relative_volume)
104
+ dims = list(range(relative_db.dim()))
105
+ dims.remove(dims[-2])
106
+ losses_per_band = relative_db.mean(dim=dims)
107
+ aggregated = [chunk.mean() for chunk in losses_per_band.chunk(self.num_aggregated_bands, dim=0)]
108
+ metrics = {f'rvm_{index}': value for index, value in enumerate(aggregated)}
109
+ metrics['rvm'] = losses_per_band.mean()
110
+ return metrics
audiocraft/metrics/visqol.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import csv
8
+ import json
9
+ import logging
10
+ from pathlib import Path
11
+ import tempfile
12
+ import typing as tp
13
+ import subprocess
14
+ import shutil
15
+
16
+ import torch
17
+ import torchaudio
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ class ViSQOL:
23
+ """ViSQOL wrapper to run ViSQOL from Python using a pre-installed binary.
24
+
25
+ To learn more about ViSQOL and how to build ViSQOL binary using bazel, please refer to the
26
+ instructions available in the open source repository: https://github.com/google/visqol
27
+
28
+ ViSQOL is capable of running in two modes:
29
+
30
+ Audio Mode:
31
+ When running in audio mode, input signals must have a 48kHz sample rate. Input should be resampled to 48kHz.
32
+ Input signals can be multi-channel, but they will be down-mixed to mono for performing the comparison.
33
+ Audio mode uses support vector regression, with the maximum range at ~4.75.
34
+
35
+ Speech Mode:
36
+ When running in speech mode, ViSQOL uses a wideband model. It therefore expects input sample rates of 16kHz.
37
+ Input should be resampled to 16kHz.
38
+ As part of the speech mode processing, a root mean square implementation for voice activity detection
39
+ is performed on the reference signal to determine what parts of the signal have voice activity and
40
+ should therefore be included in the comparison. The signal is normalized before performing the voice
41
+ activity detection.
42
+ Input signals can be multi-channel, but they will be down-mixed to mono for performing the comparison.
43
+ Speech mode is scaled to have a maximum MOS of 5.0 to match previous version behavior.
44
+
45
+ For more details, check the guidelines: https://github.com/google/visqol#general-guidelines-for-input
46
+
47
+ Args:
48
+ visqol_bin (str): Path to the ViSQOL binary.
49
+ mode (str): ViSQOL computation mode, expecting "audio" or "speech".
50
+ model (str): Name of the model to use for similarity to quality model.
51
+ debug (bool): Whether to also get debug metrics from ViSQOL or not.
52
+ """
53
+ SAMPLE_RATES_MODES = {"audio": 48_000, "speech": 16_000}
54
+ ALLOWED_SAMPLE_RATES = frozenset(SAMPLE_RATES_MODES.values())
55
+
56
+ def __init__(self, bin: tp.Union[Path, str], mode: str = "audio",
57
+ model: str = "libsvm_nu_svr_model.txt", debug: bool = False):
58
+ assert bin is not None and Path(bin).exists(), f"Could not find ViSQOL binary in specified path: {bin}"
59
+ self.visqol_bin = str(bin)
60
+ self.visqol_mode = mode
61
+ self.target_sr = self._get_target_sr(self.visqol_mode)
62
+ self.model = model
63
+ self.debug = debug
64
+ assert Path(self.visqol_model).exists(), \
65
+ f"Could not find the specified model in ViSQOL install: {self.visqol_model}"
66
+
67
+ def _get_target_sr(self, mode: str) -> int:
68
+ # returns target sampling rate for the corresponding ViSQOL mode.
69
+ if mode not in ViSQOL.SAMPLE_RATES_MODES:
70
+ raise ValueError(
71
+ f"Unsupported mode! Allowed are: {', '.join(ViSQOL.SAMPLE_RATES_MODES.keys())}"
72
+ )
73
+ return ViSQOL.SAMPLE_RATES_MODES[mode]
74
+
75
+ def _prepare_files(
76
+ self, ref_sig: torch.Tensor, deg_sig: torch.Tensor, sr: int, target_sr: int, pad_with_silence: bool = False
77
+ ):
78
+ # prepare files for ViSQOL evaluation.
79
+ assert target_sr in ViSQOL.ALLOWED_SAMPLE_RATES
80
+ assert len(ref_sig) == len(deg_sig), (
81
+ "Expects same number of ref and degraded inputs",
82
+ f" but ref len {len(ref_sig)} != deg len {len(deg_sig)}"
83
+ )
84
+ # resample audio if needed
85
+ if sr != target_sr:
86
+ transform = torchaudio.transforms.Resample(sr, target_sr)
87
+ pad = int(0.5 * target_sr)
88
+ rs_ref = []
89
+ rs_deg = []
90
+ for i in range(len(ref_sig)):
91
+ rs_ref_i = transform(ref_sig[i])
92
+ rs_deg_i = transform(deg_sig[i])
93
+ if pad_with_silence:
94
+ rs_ref_i = torch.nn.functional.pad(rs_ref_i, (pad, pad), mode='constant', value=0)
95
+ rs_deg_i = torch.nn.functional.pad(rs_deg_i, (pad, pad), mode='constant', value=0)
96
+ rs_ref.append(rs_ref_i)
97
+ rs_deg.append(rs_deg_i)
98
+ ref_sig = torch.stack(rs_ref)
99
+ deg_sig = torch.stack(rs_deg)
100
+ # save audio chunks to tmp dir and create csv
101
+ tmp_dir = Path(tempfile.mkdtemp())
102
+ try:
103
+ tmp_input_csv_path = tmp_dir / "input.csv"
104
+ tmp_results_csv_path = tmp_dir / "results.csv"
105
+ tmp_debug_json_path = tmp_dir / "debug.json"
106
+ with open(tmp_input_csv_path, "w") as csv_file:
107
+ csv_writer = csv.writer(csv_file)
108
+ csv_writer.writerow(["reference", "degraded"])
109
+ for i in range(len(ref_sig)):
110
+ tmp_ref_filename = tmp_dir / f"ref_{i}.wav"
111
+ tmp_deg_filename = tmp_dir / f"deg_{i}.wav"
112
+ torchaudio.save(
113
+ tmp_ref_filename,
114
+ torch.clamp(ref_sig[i], min=-0.99, max=0.99),
115
+ sample_rate=target_sr,
116
+ bits_per_sample=16,
117
+ encoding="PCM_S"
118
+ )
119
+ torchaudio.save(
120
+ tmp_deg_filename,
121
+ torch.clamp(deg_sig[i], min=-0.99, max=0.99),
122
+ sample_rate=target_sr,
123
+ bits_per_sample=16,
124
+ encoding="PCM_S"
125
+ )
126
+ csv_writer.writerow([str(tmp_ref_filename), str(tmp_deg_filename)])
127
+ return tmp_dir, tmp_input_csv_path, tmp_results_csv_path, tmp_debug_json_path
128
+ except Exception as e:
129
+ logger.error("Exception occurred when preparing files for ViSQOL: %s", e)
130
+ return tmp_dir, None, None, None
131
+
132
+ def _flush_files(self, tmp_dir: tp.Union[Path, str]):
133
+ # flush tmp files used to compute ViSQOL.
134
+ shutil.rmtree(str(tmp_dir))
135
+
136
+ def _collect_moslqo_score(self, results_csv_path: tp.Union[Path, str]) -> float:
137
+ # collect results for each evaluated pair and return averaged moslqo score.
138
+ with open(results_csv_path, "r") as csv_file:
139
+ reader = csv.DictReader(csv_file)
140
+ moslqo_scores = [float(row["moslqo"]) for row in reader]
141
+ if len(moslqo_scores) > 0:
142
+ return sum(moslqo_scores) / len(moslqo_scores)
143
+ else:
144
+ return 0.0
145
+
146
+ def _collect_debug_data(self, debug_json_path: tp.Union[Path, str]) -> dict:
147
+ # collect debug data for the visqol inference.
148
+ with open(debug_json_path, "r") as f:
149
+ data = json.load(f)
150
+ return data
151
+
152
+ @property
153
+ def visqol_model(self):
154
+ return f'{self.visqol_bin}/model/{self.model}'
155
+
156
+ def _run_visqol(
157
+ self,
158
+ input_csv_path: tp.Union[Path, str],
159
+ results_csv_path: tp.Union[Path, str],
160
+ debug_csv_path: tp.Optional[tp.Union[Path, str]],
161
+ ):
162
+ input_csv_path = str(input_csv_path)
163
+ results_csv_path = str(results_csv_path)
164
+ debug_csv_path = str(debug_csv_path)
165
+ cmd = [
166
+ f'{self.visqol_bin}/bazel-bin/visqol',
167
+ '--batch_input_csv', f'{input_csv_path}',
168
+ '--results_csv', f'{results_csv_path}'
169
+ ]
170
+ if debug_csv_path is not None:
171
+ cmd += ['--output_debug', f'{debug_csv_path}']
172
+ if self.visqol_mode == "speech":
173
+ cmd += ['--use_speech_mode']
174
+ cmd += ['--similarity_to_quality_model', f'{self.visqol_model}']
175
+ result = subprocess.run(cmd, capture_output=True)
176
+ if result.returncode:
177
+ logger.error("Error with visqol: \n %s \n %s", result.stdout.decode(), result.stderr.decode())
178
+ raise RuntimeError("Error while executing visqol")
179
+ result.check_returncode()
180
+
181
+ def __call__(
182
+ self,
183
+ ref_sig: torch.Tensor,
184
+ deg_sig: torch.Tensor,
185
+ sr: int,
186
+ pad_with_silence: bool = False,
187
+ ):
188
+ """Calculate the ViSQOL metric for a pair of audio signals at a given sample rate.
189
+ Args:
190
+ ref_sig (torch.Tensor): Reference signals as [B, C, T].
191
+ deg_sig (torch.Tensor): Degraded signals as [B, C, T].
192
+ sr (int): Sample rate of the two audio signals.
193
+ pad_with_silence (bool): Whether to pad the file with silences as recommended
194
+ in visqol guidelines (see: https://github.com/google/visqol#general-guidelines-for-input).
195
+ Returns:
196
+ float: The ViSQOL score or mean score for the batch.
197
+ """
198
+ logger.debug(f"Calculating visqol with mode={self.visqol_mode} on {len(ref_sig)} samples")
199
+ tmp_dir, input_csv, results_csv, debug_json = self._prepare_files(
200
+ ref_sig, deg_sig, sr, self.target_sr, pad_with_silence
201
+ )
202
+ try:
203
+ if input_csv and results_csv:
204
+ self._run_visqol(
205
+ input_csv,
206
+ results_csv,
207
+ debug_json if self.debug else None,
208
+ )
209
+ mosqol = self._collect_moslqo_score(results_csv)
210
+ return mosqol
211
+ else:
212
+ raise RuntimeError("Something unexpected happened when running VISQOL!")
213
+ except Exception as e:
214
+ logger.error("Exception occurred when running ViSQOL: %s", e)
215
+ finally:
216
+ self._flush_files(tmp_dir)