# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import contextlib from functools import partial import logging import os import typing as tp import torch import torchmetrics from ..data.audio_utils import convert_audio logger = logging.getLogger(__name__) class _patch_passt_stft: """Decorator to patch torch.stft in PaSST.""" def __init__(self): self.old_stft = torch.stft def __enter__(self): # return_complex is a mandatory parameter in latest torch versions # torch is throwing RuntimeErrors when not set torch.stft = partial(torch.stft, return_complex=False) def __exit__(self, *exc): torch.stft = self.old_stft def kl_divergence(pred_probs: torch.Tensor, target_probs: torch.Tensor, epsilon: float = 1e-6) -> torch.Tensor: """Computes the elementwise KL-Divergence loss between probability distributions from generated samples and target samples. Args: pred_probs (torch.Tensor): Probabilities for each label obtained from a classifier on generated audio. Expected shape is [B, num_classes]. target_probs (torch.Tensor): Probabilities for each label obtained from a classifier on target audio. Expected shape is [B, num_classes]. epsilon (float): Epsilon value. Returns: kld (torch.Tensor): KLD loss between each generated sample and target pair. """ kl_div = torch.nn.functional.kl_div((pred_probs + epsilon).log(), target_probs, reduction="none") return kl_div.sum(-1) class KLDivergenceMetric(torchmetrics.Metric): """Base implementation for KL Divergence metric. The KL divergence is measured between probability distributions of class predictions returned by a pre-trained audio classification model. When the KL-divergence is low, the generated audio is expected to have similar acoustic characteristics as the reference audio, according to the classifier. """ def __init__(self): super().__init__() self.add_state("kld_pq_sum", default=torch.tensor(0.), dist_reduce_fx="sum") self.add_state("kld_qp_sum", default=torch.tensor(0.), dist_reduce_fx="sum") self.add_state("kld_all_sum", default=torch.tensor(0.), dist_reduce_fx="sum") self.add_state("weight", default=torch.tensor(0), dist_reduce_fx="sum") def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor, sample_rates: torch.Tensor) -> tp.Optional[torch.Tensor]: """Get model output given provided input tensor. Args: x (torch.Tensor): Input audio tensor of shape [B, C, T]. sizes (torch.Tensor): Actual audio sample length, of shape [B]. sample_rates (torch.Tensor): Actual audio sample rate, of shape [B]. Returns: probs (torch.Tensor): Probabilities over labels, of shape [B, num_classes]. """ raise NotImplementedError("implement method to extract label distributions from the model.") def update(self, preds: torch.Tensor, targets: torch.Tensor, sizes: torch.Tensor, sample_rates: torch.Tensor) -> None: """Calculates running KL-Divergence loss between batches of audio preds (generated) and target (ground-truth) Args: preds (torch.Tensor): Audio samples to evaluate, of shape [B, C, T]. targets (torch.Tensor): Target samples to compare against, of shape [B, C, T]. sizes (torch.Tensor): Actual audio sample length, of shape [B]. sample_rates (torch.Tensor): Actual audio sample rate, of shape [B]. """ assert preds.shape == targets.shape assert preds.size(0) > 0, "Cannot update the loss with empty tensors" preds_probs = self._get_label_distribution(preds, sizes, sample_rates) targets_probs = self._get_label_distribution(targets, sizes, sample_rates) if preds_probs is not None and targets_probs is not None: assert preds_probs.shape == targets_probs.shape kld_scores = kl_divergence(preds_probs, targets_probs) assert not torch.isnan(kld_scores).any(), "kld_scores contains NaN value(s)!" self.kld_pq_sum += torch.sum(kld_scores) kld_qp_scores = kl_divergence(targets_probs, preds_probs) self.kld_qp_sum += torch.sum(kld_qp_scores) self.weight += torch.tensor(kld_scores.size(0)) def compute(self) -> dict: """Computes KL-Divergence across all evaluated pred/target pairs.""" weight: float = float(self.weight.item()) # type: ignore assert weight > 0, "Unable to compute with total number of comparisons <= 0" logger.info(f"Computing KL divergence on a total of {weight} samples") kld_pq = self.kld_pq_sum.item() / weight # type: ignore kld_qp = self.kld_qp_sum.item() / weight # type: ignore kld_both = kld_pq + kld_qp return {'kld': kld_pq, 'kld_pq': kld_pq, 'kld_qp': kld_qp, 'kld_both': kld_both} class PasstKLDivergenceMetric(KLDivergenceMetric): """KL-Divergence metric based on pre-trained PASST classifier on AudioSet. From: PaSST: Efficient Training of Audio Transformers with Patchout Paper: https://arxiv.org/abs/2110.05069 Implementation: https://github.com/kkoutini/PaSST Follow instructions from the github repo: ``` pip install 'git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt' ``` Args: pretrained_length (float, optional): Audio duration used for the pretrained model. """ def __init__(self, pretrained_length: tp.Optional[float] = None): super().__init__() self._initialize_model(pretrained_length) def _initialize_model(self, pretrained_length: tp.Optional[float] = None): """Initialize underlying PaSST audio classifier.""" model, sr, max_frames, min_frames = self._load_base_model(pretrained_length) self.min_input_frames = min_frames self.max_input_frames = max_frames self.model_sample_rate = sr self.model = model self.model.eval() self.model.to(self.device) def _load_base_model(self, pretrained_length: tp.Optional[float]): """Load pretrained model from PaSST.""" try: if pretrained_length == 30: from hear21passt.base30sec import get_basic_model # type: ignore max_duration = 30 elif pretrained_length == 20: from hear21passt.base20sec import get_basic_model # type: ignore max_duration = 20 else: from hear21passt.base import get_basic_model # type: ignore # Original PASST was trained on AudioSet with 10s-long audio samples max_duration = 10 min_duration = 0.15 min_duration = 0.15 except ModuleNotFoundError: raise ModuleNotFoundError( "Please install hear21passt to compute KL divergence: ", "pip install 'git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt'" ) model_sample_rate = 32_000 max_input_frames = int(max_duration * model_sample_rate) min_input_frames = int(min_duration * model_sample_rate) with open(os.devnull, 'w') as f, contextlib.redirect_stdout(f): model = get_basic_model(mode='logits') return model, model_sample_rate, max_input_frames, min_input_frames def _process_audio(self, wav: torch.Tensor, sample_rate: int, wav_len: int) -> tp.Optional[torch.Tensor]: wav = wav.unsqueeze(0) wav = wav[..., :wav_len] wav = convert_audio(wav, from_rate=sample_rate, to_rate=self.model_sample_rate, to_channels=1) wav = wav.squeeze(0) # create chunks of audio to match the classifier processing length segments = torch.split(wav, self.max_input_frames, dim=-1) valid_segments = [] for s in segments: if s.size(-1) > self.min_input_frames: s = torch.nn.functional.pad(s, (0, self.max_input_frames - s.shape[-1])) valid_segments.append(s) if len(valid_segments) > 0: return torch.stack(valid_segments, dim=0) else: return None def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor, sample_rates: torch.Tensor) -> tp.Optional[torch.Tensor]: """Get model output given provided input tensor. Args: x (torch.Tensor): Input audio tensor of shape [B, C, T]. sizes (torch.Tensor): Actual audio sample length, of shape [B]. sample_rates (torch.Tensor): Actual audio sample rate, of shape [B]. Returns: probs (torch.Tensor, optional): Probabilities over labels, of shape [B, num_classes]. """ all_probs: tp.List[torch.Tensor] = [] for i, wav in enumerate(x): sample_rate = int(sample_rates[i].item()) wav_len = int(sizes[i].item()) wav = self._process_audio(wav, sample_rate, wav_len) if wav is not None: assert wav.dim() == 3, f"Unexpected number of dims for preprocessed wav: {wav.shape}" wav = wav.mean(dim=1) # PaSST is printing a lot of infos that we are not interested in with open(os.devnull, 'w') as f, contextlib.redirect_stdout(f): with torch.no_grad(), _patch_passt_stft(): logits = self.model(wav.to(self.device)) probs = torch.softmax(logits, dim=-1) probs = probs.mean(dim=0) all_probs.append(probs) if len(all_probs) > 0: return torch.stack(all_probs, dim=0) else: return None