File size: 10,211 Bytes
5325fcc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
from functools import partial
import logging
import os
import typing as tp
import torch
import torchmetrics
from ..data.audio_utils import convert_audio
logger = logging.getLogger(__name__)
class _patch_passt_stft:
    """Context manager that temporarily replaces ``torch.stft`` for PaSST.

    While the ``with`` block is active, ``torch.stft`` is a partial of the
    current ``torch.stft`` with ``return_complex=False`` pre-filled. On exit,
    ``torch.stft`` is reset to the function that was installed when this
    object was constructed.
    """

    def __init__(self):
        # Snapshot taken at construction time; this is what __exit__ restores.
        self.old_stft = torch.stft

    def __enter__(self):
        # Recent torch versions require `return_complex` to be given explicitly
        # and raise a RuntimeError otherwise, so pre-fill it with False.
        patched_stft = partial(torch.stft, return_complex=False)
        torch.stft = patched_stft

    def __exit__(self, *exc):
        # Returns None, so exceptions raised inside the block propagate.
        torch.stft = self.old_stft
def kl_divergence(pred_probs: torch.Tensor, target_probs: torch.Tensor, epsilon: float = 1e-6) -> torch.Tensor:
    """Per-sample KL-Divergence between predicted and target class probabilities.

    The pointwise divergence is computed with ``torch.nn.functional.kl_div``
    (no reduction) and then summed over the last (class) dimension.

    Args:
        pred_probs (torch.Tensor): Probabilities for each label obtained
            from a classifier on generated audio. Expected shape is [B, num_classes].
        target_probs (torch.Tensor): Probabilities for each label obtained
            from a classifier on target audio. Expected shape is [B, num_classes].
        epsilon (float): Value added to ``pred_probs`` before taking the log,
            so that zero probabilities do not produce ``-inf``.
    Returns:
        kld (torch.Tensor): KLD loss between each generated sample and target pair,
            i.e. the input shape with the last dimension summed out.
    """
    # `kl_div` expects its first argument in log-space.
    log_pred = torch.log(pred_probs + epsilon)
    pointwise = torch.nn.functional.kl_div(log_pred, target_probs, reduction="none")
    return pointwise.sum(dim=-1)
class KLDivergenceMetric(torchmetrics.Metric):
    """Base implementation for KL Divergence metric.

    The KL divergence is measured between the class-probability distributions
    that a pre-trained audio classification model assigns to generated and
    reference audio. A low KL-divergence means the classifier sees the
    generated audio as acoustically similar to the reference audio.

    Subclasses must implement `_get_label_distribution`.
    """

    def __init__(self):
        super().__init__()
        # Running float sums, reduced with "sum" across processes.
        # NOTE: `kld_all_sum` is registered but not updated anywhere in this class.
        for state_name in ("kld_pq_sum", "kld_qp_sum", "kld_all_sum"):
            self.add_state(state_name, default=torch.tensor(0.), dist_reduce_fx="sum")
        # Integer count of rows that contributed to the sums.
        self.add_state("weight", default=torch.tensor(0), dist_reduce_fx="sum")

    def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor,
                                sample_rates: torch.Tensor) -> tp.Optional[torch.Tensor]:
        """Get model output given provided input tensor.

        Args:
            x (torch.Tensor): Input audio tensor of shape [B, C, T].
            sizes (torch.Tensor): Actual audio sample length, of shape [B].
            sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
        Returns:
            probs (torch.Tensor): Probabilities over labels, of shape [B, num_classes].
        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("implement method to extract label distributions from the model.")

    def update(self, preds: torch.Tensor, targets: torch.Tensor,
               sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
        """Accumulate the KL-Divergence between a batch of generated and reference audio.

        Nothing is accumulated when either label distribution comes back as ``None``.

        Args:
            preds (torch.Tensor): Audio samples to evaluate, of shape [B, C, T].
            targets (torch.Tensor): Target samples to compare against, of shape [B, C, T].
            sizes (torch.Tensor): Actual audio sample length, of shape [B].
            sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
        """
        assert preds.shape == targets.shape
        assert preds.size(0) > 0, "Cannot update the loss with empty tensors"
        preds_probs = self._get_label_distribution(preds, sizes, sample_rates)
        targets_probs = self._get_label_distribution(targets, sizes, sample_rates)
        if preds_probs is None or targets_probs is None:
            # No usable distribution for this batch: leave the states untouched.
            return
        assert preds_probs.shape == targets_probs.shape
        # "pq" direction: predictions first, targets second.
        pq_scores = kl_divergence(preds_probs, targets_probs)
        assert not torch.isnan(pq_scores).any(), "kld_scores contains NaN value(s)!"
        self.kld_pq_sum += pq_scores.sum()
        # "qp" direction: same call with the arguments swapped.
        qp_scores = kl_divergence(targets_probs, preds_probs)
        self.kld_qp_sum += qp_scores.sum()
        self.weight += torch.tensor(pq_scores.size(0))

    def compute(self) -> dict:
        """Return the accumulated KL-Divergences averaged over all compared pairs.

        Returns:
            dict: ``kld`` and ``kld_pq`` (same value), ``kld_qp``, and
                ``kld_both`` (sum of both directions).
        """
        weight: float = float(self.weight.item())  # type: ignore
        assert weight > 0, "Unable to compute with total number of comparisons <= 0"
        logger.info(f"Computing KL divergence on a total of {weight} samples")
        kld_pq = self.kld_pq_sum.item() / weight  # type: ignore
        kld_qp = self.kld_qp_sum.item() / weight  # type: ignore
        return {
            'kld': kld_pq,
            'kld_pq': kld_pq,
            'kld_qp': kld_qp,
            'kld_both': kld_pq + kld_qp,
        }
class PasstKLDivergenceMetric(KLDivergenceMetric):
    """KL-Divergence metric based on pre-trained PASST classifier on AudioSet.

    From: PaSST: Efficient Training of Audio Transformers with Patchout
    Paper: https://arxiv.org/abs/2110.05069
    Implementation: https://github.com/kkoutini/PaSST

    Follow instructions from the github repo:
    ```
    pip install 'git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt'
    ```

    Args:
        pretrained_length (float, optional): Audio duration used for the pretrained model.
            20 and 30 select the matching hear21passt variants; any other value
            (including None) selects the base model with a 10 second window.
    """

    def __init__(self, pretrained_length: tp.Optional[float] = None):
        super().__init__()
        self._initialize_model(pretrained_length)

    def _initialize_model(self, pretrained_length: tp.Optional[float] = None):
        """Initialize underlying PaSST audio classifier."""
        model, sr, max_frames, min_frames = self._load_base_model(pretrained_length)
        self.min_input_frames = min_frames
        self.max_input_frames = max_frames
        self.model_sample_rate = sr
        self.model = model
        self.model.eval()
        self.model.to(self.device)

    def _load_base_model(self, pretrained_length: tp.Optional[float]) -> tp.Tuple[tp.Any, int, int, int]:
        """Load pretrained model from PaSST.

        Args:
            pretrained_length (float, optional): Requested pretrained duration (see class docstring).
        Returns:
            tuple: (model, model sample rate, max input frames, min input frames).
        Raises:
            ModuleNotFoundError: If hear21passt (or one of its imports) cannot be found.
        """
        try:
            if pretrained_length == 30:
                from hear21passt.base30sec import get_basic_model  # type: ignore
                max_duration = 30
            elif pretrained_length == 20:
                from hear21passt.base20sec import get_basic_model  # type: ignore
                max_duration = 20
            else:
                from hear21passt.base import get_basic_model  # type: ignore
                # Original PASST was trained on AudioSet with 10s-long audio samples
                max_duration = 10
        except ModuleNotFoundError as err:
            # Build ONE message string (adjacent literals concatenate); passing two
            # comma-separated strings would make the message render as a tuple.
            raise ModuleNotFoundError(
                "Please install hear21passt to compute KL divergence: "
                "pip install 'git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt'"
            ) from err
        # Segments not longer than this many seconds are dropped in `_process_audio`.
        min_duration = 0.15
        model_sample_rate = 32_000
        max_input_frames = int(max_duration * model_sample_rate)
        min_input_frames = int(min_duration * model_sample_rate)
        # Silence whatever the model construction prints to stdout.
        with open(os.devnull, 'w') as f, contextlib.redirect_stdout(f):
            model = get_basic_model(mode='logits')
        return model, model_sample_rate, max_input_frames, min_input_frames

    def _process_audio(self, wav: torch.Tensor, sample_rate: int, wav_len: int) -> tp.List[torch.Tensor]:
        """Process audio to feed to the pretrained model.

        The waveform is trimmed to `wav_len` samples, converted to the model
        sample rate and a single channel, then split along time into chunks of
        at most `self.max_input_frames` frames.

        Args:
            wav (torch.Tensor): Single audio item (one element of a [B, C, T] batch).
            sample_rate (int): Sample rate of `wav`.
            wav_len (int): Number of valid samples in `wav`.
        Returns:
            list of torch.Tensor: Kept segments, each with a leading dimension
                of size 1 added. May be empty.
        """
        wav = wav.unsqueeze(0)
        wav = wav[..., :wav_len]
        wav = convert_audio(wav, from_rate=sample_rate, to_rate=self.model_sample_rate, to_channels=1)
        wav = wav.squeeze(0)
        # we don't pad but return a list of audio segments as this otherwise affects the KLD computation
        segments = torch.split(wav, self.max_input_frames, dim=-1)
        # ignoring too small segments that are breaking the model inference
        return [s[None] for s in segments if s.size(-1) > self.min_input_frames]

    def _get_model_preds(self, wav: torch.Tensor) -> torch.Tensor:
        """Run the pretrained model and get the predictions.

        Args:
            wav (torch.Tensor): Preprocessed audio with exactly 3 dimensions;
                dimension 1 is averaged out before the model call.
        Returns:
            torch.Tensor: Softmax of the model logits over the last dimension.
        """
        assert wav.dim() == 3, f"Unexpected number of dims for preprocessed wav: {wav.shape}"
        wav = wav.mean(dim=1)
        # PaSST is printing a lot of garbage that we are not interested in
        with open(os.devnull, "w") as f, contextlib.redirect_stdout(f):
            with torch.no_grad(), _patch_passt_stft():
                logits = self.model(wav.to(self.device))
                probs = torch.softmax(logits, dim=-1)
                return probs

    def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor,
                                sample_rates: torch.Tensor) -> tp.Optional[torch.Tensor]:
        """Get model output given provided input tensor.

        Args:
            x (torch.Tensor): Input audio tensor of shape [B, C, T].
            sizes (torch.Tensor): Actual audio sample length, of shape [B].
            sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
        Returns:
            probs (torch.Tensor, optional): Probabilities over labels, stacked as
                one row per kept audio segment, i.e. shape [N, num_classes] where N
                is the total number of segments over the batch (N == B only when
                every item yields exactly one segment). None if no segment was kept.
        """
        all_probs: tp.List[torch.Tensor] = []
        for i, wav in enumerate(x):
            sample_rate = int(sample_rates[i].item())
            wav_len = int(sizes[i].item())
            wav_segments = self._process_audio(wav, sample_rate, wav_len)
            for segment in wav_segments:
                # Average over the first dimension of the model output
                # (presumably the size-1 batch axis added in `_process_audio`).
                probs = self._get_model_preds(segment).mean(dim=0)
                all_probs.append(probs)
        if all_probs:
            return torch.stack(all_probs, dim=0)
        return None
|