Spaces:
Build error
Build error
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
import contextlib | |
from functools import partial | |
import logging | |
import os | |
import typing as tp | |
import torch | |
import torchmetrics | |
from ..data.audio_utils import convert_audio | |
# Per-module logger, keyed on this module's import path.
logger = logging.getLogger(name=__name__)
class _patch_passt_stft: | |
"""Decorator to patch torch.stft in PaSST.""" | |
def __init__(self): | |
self.old_stft = torch.stft | |
def __enter__(self): | |
# return_complex is a mandatory parameter in latest torch versions | |
# torch is throwing RuntimeErrors when not set | |
torch.stft = partial(torch.stft, return_complex=False) | |
def __exit__(self, *exc): | |
torch.stft = self.old_stft | |
def kl_divergence(pred_probs: torch.Tensor, target_probs: torch.Tensor, epsilon: float = 1e-6) -> torch.Tensor:
    """Per-sample KL divergence between target and predicted class distributions.

    ``torch.nn.functional.kl_div(input, target)`` expects ``input`` as
    log-probabilities. The predictions are therefore shifted by ``epsilon`` before
    taking the log, to avoid ``log(0)``. The per-class terms
    ``target * (log(target) - log(pred + epsilon))`` are then summed over the last
    dimension, which gives KL(target || pred) for each pair.

    Args:
        pred_probs (torch.Tensor): Probabilities for each label obtained
            from a classifier on generated audio. Expected shape is [B, num_classes].
        target_probs (torch.Tensor): Probabilities for each label obtained
            from a classifier on target audio. Expected shape is [B, num_classes].
        epsilon (float): Small constant added to ``pred_probs`` before the log.
    Returns:
        kld (torch.Tensor): KL divergence for each generated/target pair, shape [B].
    """
    log_pred = torch.log(pred_probs + epsilon)
    per_class = torch.nn.functional.kl_div(log_pred, target_probs, reduction="none")
    return per_class.sum(dim=-1)
class KLDivergenceMetric(torchmetrics.Metric):
    """Base implementation for KL Divergence metric.

    The KL divergence is measured between probability distributions
    of class predictions returned by a pre-trained audio classification model.
    When the KL-divergence is low, the generated audio is expected to
    have similar acoustic characteristics as the reference audio,
    according to the classifier.

    Subclasses must implement ``_get_label_distribution``.
    """

    def __init__(self):
        super().__init__()
        # Running sums, reduced with "sum" when states are synced across processes.
        # kld_pq accumulates kl_divergence(preds, targets); kld_qp the swapped call.
        self.add_state("kld_pq_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state("kld_qp_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
        # NOTE(review): kld_all_sum is never updated or read in this class; it is kept
        # so the registered state layout stays unchanged for any external code.
        self.add_state("kld_all_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
        # Number of (pred, target) pairs that contributed to the sums above.
        self.add_state("weight", default=torch.tensor(0), dist_reduce_fx="sum")

    def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor,
                                sample_rates: torch.Tensor) -> tp.Optional[torch.Tensor]:
        """Get model output given provided input tensor.

        Args:
            x (torch.Tensor): Input audio tensor of shape [B, C, T].
            sizes (torch.Tensor): Actual audio sample length, of shape [B].
            sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
        Returns:
            probs (torch.Tensor): Probabilities over labels, of shape [B, num_classes],
                or None when no usable distribution could be extracted.
        """
        raise NotImplementedError("implement method to extract label distributions from the model.")

    def update(self, preds: torch.Tensor, targets: torch.Tensor,
               sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
        """Calculates running KL-Divergence loss between batches of audio
        preds (generated) and target (ground-truth).

        The whole batch is skipped when either side yields no label distribution.

        Args:
            preds (torch.Tensor): Audio samples to evaluate, of shape [B, C, T].
            targets (torch.Tensor): Target samples to compare against, of shape [B, C, T].
            sizes (torch.Tensor): Actual audio sample length, of shape [B].
            sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
        Raises:
            AssertionError: On mismatched shapes, empty input, or NaN scores.
        """
        assert preds.shape == targets.shape
        assert preds.size(0) > 0, "Cannot update the loss with empty tensors"
        preds_probs = self._get_label_distribution(preds, sizes, sample_rates)
        targets_probs = self._get_label_distribution(targets, sizes, sample_rates)
        if preds_probs is None or targets_probs is None:
            return
        assert preds_probs.shape == targets_probs.shape
        # Compute and validate both directions before touching any state, so a failed
        # assertion cannot leave the running sums partially updated.
        kld_scores = kl_divergence(preds_probs, targets_probs)
        assert not torch.isnan(kld_scores).any(), "kld_scores contains NaN value(s)!"
        kld_qp_scores = kl_divergence(targets_probs, preds_probs)
        assert not torch.isnan(kld_qp_scores).any(), "kld_qp_scores contains NaN value(s)!"
        self.kld_pq_sum += torch.sum(kld_scores)
        self.kld_qp_sum += torch.sum(kld_qp_scores)
        # Plain int increment: no per-update tensor allocation on the CPU.
        self.weight += kld_scores.size(0)

    def compute(self) -> dict:
        """Computes KL-Divergence across all evaluated pred/target pairs.

        Returns:
            dict: Mean scores keyed by 'kld' (same value as 'kld_pq'), 'kld_pq',
                'kld_qp' and 'kld_both' (= kld_pq + kld_qp).
        Raises:
            AssertionError: If no pair has been accumulated yet.
        """
        weight: float = float(self.weight.item())  # type: ignore
        assert weight > 0, "Unable to compute with total number of comparisons <= 0"
        logger.info("Computing KL divergence on a total of %s samples", weight)
        kld_pq = self.kld_pq_sum.item() / weight  # type: ignore
        kld_qp = self.kld_qp_sum.item() / weight  # type: ignore
        kld_both = kld_pq + kld_qp
        return {'kld': kld_pq, 'kld_pq': kld_pq, 'kld_qp': kld_qp, 'kld_both': kld_both}
class PasstKLDivergenceMetric(KLDivergenceMetric):
    """KL-Divergence metric based on pre-trained PASST classifier on AudioSet.

    From: PaSST: Efficient Training of Audio Transformers with Patchout
    Paper: https://arxiv.org/abs/2110.05069
    Implementation: https://github.com/kkoutini/PaSST

    Follow instructions from the github repo:
    ```
    pip install 'git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt'
    ```

    Args:
        pretrained_length (float, optional): Audio duration used for the pretrained model.
            30 and 20 select the ``base30sec`` / ``base20sec`` variants; any other value
            (including None) selects ``hear21passt.base`` with a 10s input window.
    """

    def __init__(self, pretrained_length: tp.Optional[float] = None):
        super().__init__()
        self._initialize_model(pretrained_length)

    def _initialize_model(self, pretrained_length: tp.Optional[float] = None):
        """Initialize underlying PaSST audio classifier and its input constraints."""
        model, sr, max_frames, min_frames = self._load_base_model(pretrained_length)
        self.min_input_frames = min_frames
        self.max_input_frames = max_frames
        self.model_sample_rate = sr
        self.model = model
        self.model.eval()
        self.model.to(self.device)

    def _load_base_model(self, pretrained_length: tp.Optional[float]):
        """Load pretrained model from PaSST.

        Returns:
            tuple: (model, model_sample_rate, max_input_frames, min_input_frames).
        Raises:
            ModuleNotFoundError: If ``hear21passt`` is not installed.
        """
        # Defined once, independent of the selected variant, so it is always bound.
        min_duration = 0.15
        try:
            if pretrained_length == 30:
                from hear21passt.base30sec import get_basic_model  # type: ignore
                max_duration = 30
            elif pretrained_length == 20:
                from hear21passt.base20sec import get_basic_model  # type: ignore
                max_duration = 20
            else:
                from hear21passt.base import get_basic_model  # type: ignore
                # Original PASST was trained on AudioSet with 10s-long audio samples
                max_duration = 10
        except ModuleNotFoundError as err:
            raise ModuleNotFoundError(
                "Please install hear21passt to compute KL divergence: "
                "pip install 'git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt'"
            ) from err
        model_sample_rate = 32_000
        max_input_frames = int(max_duration * model_sample_rate)
        min_input_frames = int(min_duration * model_sample_rate)
        # Silence whatever get_basic_model prints to stdout.
        with open(os.devnull, 'w') as f, contextlib.redirect_stdout(f):
            model = get_basic_model(mode='logits')
        return model, model_sample_rate, max_input_frames, min_input_frames

    def _process_audio(self, wav: torch.Tensor, sample_rate: int, wav_len: int) -> tp.Optional[torch.Tensor]:
        """Resample/downmix one example and cut it into fixed-length model inputs.

        Args:
            wav (torch.Tensor): One item of the [B, C, T] batch, i.e. [C, T].
            sample_rate (int): Sample rate of ``wav``.
            wav_len (int): Number of valid samples; anything after is discarded.
        Returns:
            torch.Tensor, optional: Segments of ``max_input_frames`` samples (zero-padded),
                stacked along a new leading dim; only segments strictly longer than
                ``min_input_frames`` are kept. None if no segment qualifies.
        """
        wav = wav.unsqueeze(0)
        wav = wav[..., :wav_len]
        wav = convert_audio(wav, from_rate=sample_rate, to_rate=self.model_sample_rate, to_channels=1)
        wav = wav.squeeze(0)
        # Create chunks of audio to match the classifier processing length.
        segments = torch.split(wav, self.max_input_frames, dim=-1)
        valid_segments = [
            torch.nn.functional.pad(s, (0, self.max_input_frames - s.shape[-1]))
            for s in segments
            if s.size(-1) > self.min_input_frames
        ]
        if not valid_segments:
            return None
        return torch.stack(valid_segments, dim=0)

    def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor,
                                sample_rates: torch.Tensor) -> tp.Optional[torch.Tensor]:
        """Get model output given provided input tensor.

        Each example's softmax probabilities are averaged over its segments.
        Examples too short to yield any segment are skipped, so the result can have
        fewer rows than the batch.

        Args:
            x (torch.Tensor): Input audio tensor of shape [B, C, T].
            sizes (torch.Tensor): Actual audio sample length, of shape [B].
            sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
        Returns:
            probs (torch.Tensor, optional): Probabilities over labels, of shape
                [B', num_classes] with B' <= B, or None if every example was skipped.
        """
        all_probs: tp.List[torch.Tensor] = []
        for i, wav in enumerate(x):
            sample_rate = int(sample_rates[i].item())
            wav_len = int(sizes[i].item())
            wav = self._process_audio(wav, sample_rate, wav_len)
            if wav is None:
                continue
            assert wav.dim() == 3, f"Unexpected number of dims for preprocessed wav: {wav.shape}"
            # Collapse the (single) channel dim before feeding the classifier.
            wav = wav.mean(dim=1)
            # PaSST is printing a lot of infos that we are not interested in
            with open(os.devnull, 'w') as f, contextlib.redirect_stdout(f):
                with torch.no_grad(), _patch_passt_stft():
                    logits = self.model(wav.to(self.device))
                    probs = torch.softmax(logits, dim=-1)
                    probs = probs.mean(dim=0)
                    all_probs.append(probs)
        if not all_probs:
            return None
        return torch.stack(all_probs, dim=0)