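# NOTE(review): the four lines below look like page-header residue from the
# hosting site; kept as comments so the module parses.
# reach-vb's picture
# reach-vb HF staff
# Stereo demo update (#60)
# 5325fcc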
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
from functools import partial
import logging
import os
import typing as tp
import torch
import torchmetrics
from ..data.audio_utils import convert_audio
logger = logging.getLogger(__name__)
class _patch_passt_stft:
    """Context manager that temporarily replaces ``torch.stft``.

    Inside the ``with`` block, ``torch.stft`` is a ``functools.partial`` of the
    current ``torch.stft`` with ``return_complex=False`` preset. On exit,
    ``torch.stft`` is reset to the function that was bound when this object
    was constructed.
    """

    def __init__(self):
        # Snapshot taken at construction time; this is what __exit__ restores.
        self.old_stft = torch.stft

    def __enter__(self):
        # Recent torch versions require return_complex to be given explicitly
        # and raise a RuntimeError otherwise, so preset it for callers that
        # do not pass it.
        patched_stft = partial(torch.stft, return_complex=False)
        torch.stft = patched_stft

    def __exit__(self, *exc):
        # Always put the saved function back, whether or not an error occurred.
        torch.stft = self.old_stft
def kl_divergence(pred_probs: torch.Tensor, target_probs: torch.Tensor, epsilon: float = 1e-6) -> torch.Tensor:
    """Per-sample KL-Divergence between predicted and target label distributions.

    The pointwise divergence is evaluated with ``torch.nn.functional.kl_div``
    (``reduction="none"``), using ``log(pred_probs + epsilon)`` as the input and
    ``target_probs`` as the target, and is then summed over the last dimension.

    Args:
        pred_probs (torch.Tensor): Probabilities for each label obtained
            from a classifier on generated audio. Expected shape is [B, num_classes].
        target_probs (torch.Tensor): Probabilities for each label obtained
            from a classifier on target audio. Expected shape is [B, num_classes].
        epsilon (float): Value added to ``pred_probs`` before taking the log.
    Returns:
        kld (torch.Tensor): KLD value for each generated/target pair, i.e. the
            input shape with the last dimension reduced.
    """
    # kl_div expects its first argument in log-space; epsilon keeps log() finite at 0.
    log_pred = torch.log(pred_probs + epsilon)
    pointwise = torch.nn.functional.kl_div(log_pred, target_probs, reduction="none")
    return pointwise.sum(dim=-1)
class KLDivergenceMetric(torchmetrics.Metric):
    """Base class for a running KL Divergence metric.

    The divergence is taken between the label probability distributions that
    an audio classifier assigns to generated and reference audio. Subclasses
    supply the classifier by overriding `_get_label_distribution`.
    A low KL-divergence means that, according to the classifier, the generated
    audio has acoustic characteristics similar to the reference audio.
    """

    def __init__(self):
        super().__init__()
        # Running float sums; every state is registered with a "sum" reduction.
        # Note: `kld_all_sum` is registered but not updated anywhere in this class.
        for state_name in ("kld_pq_sum", "kld_qp_sum", "kld_all_sum"):
            self.add_state(state_name, default=torch.tensor(0.), dist_reduce_fx="sum")
        # Number of compared pairs accumulated so far (integer tensor).
        self.add_state("weight", default=torch.tensor(0), dist_reduce_fx="sum")

    def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor,
                                sample_rates: torch.Tensor) -> tp.Optional[torch.Tensor]:
        """Return the classifier's label probabilities for the given audio.

        Must be overridden by subclasses; this base version always raises.

        Args:
            x (torch.Tensor): Input audio tensor of shape [B, C, T].
            sizes (torch.Tensor): Actual audio sample length, of shape [B].
            sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
        Returns:
            probs (torch.Tensor, optional): Probabilities over labels, or None
                when no distribution could be produced.
        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("implement method to extract label distributions from the model.")

    def update(self, preds: torch.Tensor, targets: torch.Tensor,
               sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
        """Accumulate the KL-Divergence between a batch of generated audio
        and the matching ground-truth audio.

        If either label distribution is None the batch is skipped and no
        state is modified.

        Args:
            preds (torch.Tensor): Audio samples to evaluate, of shape [B, C, T].
            targets (torch.Tensor): Target samples to compare against, of shape [B, C, T].
            sizes (torch.Tensor): Actual audio sample length, of shape [B].
            sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
        """
        assert preds.shape == targets.shape
        assert preds.size(0) > 0, "Cannot update the loss with empty tensors"
        preds_probs = self._get_label_distribution(preds, sizes, sample_rates)
        targets_probs = self._get_label_distribution(targets, sizes, sample_rates)
        if preds_probs is None or targets_probs is None:
            # Nothing usable was extracted for this batch.
            return
        assert preds_probs.shape == targets_probs.shape
        # "pq" direction: predictions as input, targets as reference.
        forward_scores = kl_divergence(preds_probs, targets_probs)
        assert not torch.isnan(forward_scores).any(), "kld_scores contains NaN value(s)!"
        self.kld_pq_sum += torch.sum(forward_scores)
        # "qp" direction: same computation with the two roles swapped.
        reverse_scores = kl_divergence(targets_probs, preds_probs)
        self.kld_qp_sum += torch.sum(reverse_scores)
        # One comparison is counted per row of the score tensor.
        self.weight += torch.tensor(forward_scores.size(0))

    def compute(self) -> dict:
        """Return the average KL-Divergence over every accumulated comparison.

        Returns:
            dict: 'kld' and 'kld_pq' (the same value), 'kld_qp', and 'kld_both'
                which is the sum of the two directions.
        """
        total: float = float(self.weight.item())  # type: ignore
        assert total > 0, "Unable to compute with total number of comparisons <= 0"
        logger.info(f"Computing KL divergence on a total of {total} samples")
        mean_pq = self.kld_pq_sum.item() / total  # type: ignore
        mean_qp = self.kld_qp_sum.item() / total  # type: ignore
        return {'kld': mean_pq, 'kld_pq': mean_pq, 'kld_qp': mean_qp, 'kld_both': mean_pq + mean_qp}
class PasstKLDivergenceMetric(KLDivergenceMetric):
    """KL-Divergence metric based on pre-trained PASST classifier on AudioSet.

    From: PaSST: Efficient Training of Audio Transformers with Patchout
    Paper: https://arxiv.org/abs/2110.05069
    Implementation: https://github.com/kkoutini/PaSST

    Follow instructions from the github repo:
    ```
    pip install 'git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt'
    ```

    Args:
        pretrained_length (float, optional): Audio duration used for the pretrained model.
            30 and 20 select the matching PaSST variants; any other value
            (including None) selects the base model with a 10 second window.
    """

    def __init__(self, pretrained_length: tp.Optional[float] = None):
        super().__init__()
        self._initialize_model(pretrained_length)

    def _initialize_model(self, pretrained_length: tp.Optional[float] = None):
        """Initialize underlying PaSST audio classifier."""
        model, sr, max_frames, min_frames = self._load_base_model(pretrained_length)
        self.min_input_frames = min_frames
        self.max_input_frames = max_frames
        self.model_sample_rate = sr
        self.model = model
        self.model.eval()
        self.model.to(self.device)

    def _load_base_model(self, pretrained_length: tp.Optional[float]):
        """Load pretrained model from PaSST.

        Args:
            pretrained_length (float, optional): 30 or 20 to pick the matching
                PaSST variant, anything else for the base model.
        Returns:
            tuple: (model, model sample rate, max input frames, min input frames).
        Raises:
            ModuleNotFoundError: If the `hear21passt` package is not installed.
        """
        try:
            if pretrained_length == 30:
                from hear21passt.base30sec import get_basic_model  # type: ignore
                max_duration = 30
            elif pretrained_length == 20:
                from hear21passt.base20sec import get_basic_model  # type: ignore
                max_duration = 20
            else:
                from hear21passt.base import get_basic_model  # type: ignore
                # Original PASST was trained on AudioSet with 10s-long audio samples
                max_duration = 10
        except ModuleNotFoundError as err:
            # A single message string (no comma between the literals) so that the
            # exception renders as readable text instead of a tuple of arguments.
            raise ModuleNotFoundError(
                "Please install hear21passt to compute KL divergence: "
                "pip install 'git+https://github.com/kkoutini/passt_hear21@0.0.19#egg=hear21passt'"
            ) from err
        # Segments at or below this duration (in seconds) are dropped in _process_audio.
        min_duration = 0.15
        model_sample_rate = 32_000
        max_input_frames = int(max_duration * model_sample_rate)
        min_input_frames = int(min_duration * model_sample_rate)
        # Silence whatever the model factory prints to stdout.
        with open(os.devnull, 'w') as f, contextlib.redirect_stdout(f):
            model = get_basic_model(mode='logits')
        return model, model_sample_rate, max_input_frames, min_input_frames

    def _process_audio(self, wav: torch.Tensor, sample_rate: int, wav_len: int) -> tp.List[torch.Tensor]:
        """Process audio to feed to the pretrained model.

        The waveform is trimmed to `wav_len` samples, converted to the model
        sample rate with a single channel, and split along time into chunks of
        at most `max_input_frames`.

        Args:
            wav (torch.Tensor): A single audio item (one element of the batch).
            sample_rate (int): Sample rate of `wav`.
            wav_len (int): Number of valid samples in `wav`.
        Returns:
            list of torch.Tensor: Kept segments, each with an extra leading
                dimension. Segments whose length is not strictly greater than
                `min_input_frames` are discarded, so the list may be empty.
        """
        wav = wav.unsqueeze(0)
        wav = wav[..., :wav_len]
        wav = convert_audio(wav, from_rate=sample_rate, to_rate=self.model_sample_rate, to_channels=1)
        wav = wav.squeeze(0)
        # we don't pad but return a list of audio segments as this otherwise affects the KLD computation
        segments = torch.split(wav, self.max_input_frames, dim=-1)
        # ignoring too small segments that are breaking the model inference
        return [s[None] for s in segments if s.size(-1) > self.min_input_frames]

    def _get_model_preds(self, wav: torch.Tensor) -> torch.Tensor:
        """Run the pretrained model and get the predictions.

        Args:
            wav (torch.Tensor): Preprocessed audio with exactly 3 dimensions;
                dimension 1 is averaged out before the model is called.
        Returns:
            torch.Tensor: Softmax over the last dimension of the model output.
        """
        assert wav.dim() == 3, f"Unexpected number of dims for preprocessed wav: {wav.shape}"
        wav = wav.mean(dim=1)
        # PaSST is printing a lot of garbage that we are not interested in
        with open(os.devnull, "w") as f, contextlib.redirect_stdout(f):
            with torch.no_grad(), _patch_passt_stft():
                logits = self.model(wav.to(self.device))
                probs = torch.softmax(logits, dim=-1)
                return probs

    def _get_label_distribution(self, x: torch.Tensor, sizes: torch.Tensor,
                                sample_rates: torch.Tensor) -> tp.Optional[torch.Tensor]:
        """Get model output given provided input tensor.

        Each batch item is split into segments by `_process_audio` and every
        kept segment contributes its own row to the output.

        Args:
            x (torch.Tensor): Input audio tensor of shape [B, C, T].
            sizes (torch.Tensor): Actual audio sample length, of shape [B].
            sample_rates (torch.Tensor): Actual audio sample rate, of shape [B].
        Returns:
            probs (torch.Tensor, optional): Probabilities over labels, one row
                per kept segment across the whole batch (this equals [B, num_classes]
                only when every item yields exactly one segment). None when no
                segment was kept.
        """
        all_probs: tp.List[torch.Tensor] = []
        for i, wav in enumerate(x):
            sample_rate = int(sample_rates[i].item())
            wav_len = int(sizes[i].item())
            wav_segments = self._process_audio(wav, sample_rate, wav_len)
            for segment in wav_segments:
                # Average over the leading dimension added in _process_audio.
                probs = self._get_model_preds(segment).mean(dim=0)
                all_probs.append(probs)
        if all_probs:
            return torch.stack(all_probs, dim=0)
        return None