from pathlib import Path
import typing as tp

import torch
import torchmetrics
from transformers import RobertaTokenizer

from ..data.audio_utils import convert_audio
from ..environment import AudioCraftEnvironment
from ..utils.utils import load_clap_state_dict

try:
    import laion_clap
except ImportError:
    laion_clap = None


class TextConsistencyMetric(torchmetrics.Metric):
    """Text consistency metric measuring consistency between audio and text pairs."""

    def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
        raise NotImplementedError("implement how to update the metric from the audio and text pairs.")

    def compute(self):
        raise NotImplementedError("implement how to compute the final metric score.")


class CLAPTextConsistencyMetric(TextConsistencyMetric):
    """Text consistency metric relying on Contrastive Language-Audio Pretraining (CLAP).

    This metric is similar to the MuLan Cycle Consistency from MusicLM (https://arxiv.org/pdf/2301.11325.pdf)
    or the CLAP score used in Make-An-Audio (https://arxiv.org/pdf/2301.12661v1.pdf).

    As a joint audio-text embedding model, a pretrained CLAP model can be used to quantify the
    similarity between audio-text pairs. We compute the CLAP embeddings from the text descriptions
    as well as from the audio generated from them, and define the MCC metric as the average cosine
    similarity between these embeddings.

    Model implementation & pre-trained checkpoints: https://github.com/LAION-AI/CLAP
    """
    def __init__(self, model_path: tp.Union[str, Path], model_arch: str = 'HTSAT-tiny', enable_fusion: bool = False):
        super().__init__()
        if laion_clap is None:
            raise ImportError("Please install CLAP to compute text consistency: 'pip install laion_clap'")
        self.add_state("cosine_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state("weight", default=torch.tensor(0.), dist_reduce_fx="sum")
        self._initialize_model(model_path, model_arch, enable_fusion)

    def _initialize_model(self, model_path: tp.Union[str, Path], model_arch: str, enable_fusion: bool):
        model_path = AudioCraftEnvironment.resolve_reference_path(model_path)
        self.tokenize = RobertaTokenizer.from_pretrained('roberta-base')
        self.model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch)
        self.model_sample_rate = 48_000
        load_clap_state_dict(self.model, model_path)
        self.model.eval()

    def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict:
        # we use the default tokenization params from the CLAP module here as well
        return self.tokenize(texts, padding="max_length", truncation=True, max_length=77, return_tensors="pt")

    def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
        """Compute cosine similarity between audio and text pairs and accumulate scores over the dataset."""
        assert audio.size(0) == len(text), "Number of audio and text samples should match"
        assert torch.all(sample_rates == sample_rates[0].item()), "All items in batch should have the same sample rate"
        sample_rate = int(sample_rates[0].item())
        # convert audio batch to monophonic audio at the model sample rate, with no channel dimension: [B, C, T] -> [B, T]
        audio = convert_audio(audio, from_rate=sample_rate, to_rate=self.model_sample_rate, to_channels=1).mean(dim=1)
        audio_embeddings = self.model.get_audio_embedding_from_data(audio, use_tensor=True)
        text_embeddings = self.model.get_text_embedding(text, tokenizer=self._tokenizer, use_tensor=True)
        # cosine similarity between each audio embedding and its paired text embedding
        cosine_sim = torch.nn.functional.cosine_similarity(audio_embeddings, text_embeddings, dim=1, eps=1e-8)
        self.cosine_sum += cosine_sim.sum(dim=0)
        self.weight += torch.tensor(cosine_sim.size(0))

    def compute(self):
        """Computes the average cosine similarity across all audio/text pairs."""
        assert self.weight.item() > 0, "Unable to compute with total number of comparisons <= 0"
        return (self.cosine_sum / self.weight).item()
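

# -----------------------------------------------------------------------------
# Minimal usage sketch (not part of the metric implementation above). It assumes
# a locally available CLAP checkpoint; the checkpoint path, batch shapes, text
# prompts, and sample rate below are illustrative placeholders rather than
# values prescribed by this module.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    metric = CLAPTextConsistencyMetric("/path/to/clap_checkpoint.pt")  # hypothetical checkpoint path
    # fake batch of two one-second mono clips at 32 kHz, shaped [B, C, T]
    audio = torch.randn(2, 1, 32_000)
    text = ["happy rock", "calm piano melody"]
    sizes = torch.tensor([32_000, 32_000])  # valid lengths per item (not used by this metric's update)
    sample_rates = torch.tensor([32_000, 32_000])
    metric.update(audio, text, sizes, sample_rates)
    print(f"CLAP text consistency: {metric.compute():.4f}")  # average cosine similarity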