|
import torch |
|
import evaluate |
|
from transformers import AutoModel, AutoProcessor, pipeline, WhisperForConditionalGeneration, WhisperTokenizer, WhisperTokenizerFast |
|
|
|
|
|
def clap_similarity(clap_model_name_or_path, texts, audios, device): |
|
clap = AutoModel.from_pretrained(clap_model_name_or_path) |
|
clap_processor = AutoProcessor.from_pretrained(clap_model_name_or_path) |
|
clap_inputs = clap_processor(text=texts, audios=audios, padding=True, return_tensors="pt").to(device) |
|
clap.to(device) |
|
with torch.no_grad(): |
|
text_features = clap.get_text_features( |
|
clap_inputs["input_ids"], attention_mask=clap_inputs.get("attention_mask", None) |
|
) |
|
audio_features = clap.get_audio_features(clap_inputs["input_features"]) |
|
|
|
cosine_sim = torch.nn.functional.cosine_similarity(audio_features, text_features, dim=1, eps=1e-8) |
|
|
|
clap.to("cpu") |
|
clap_inputs.to("cpu") |
|
return cosine_sim.mean().to("cpu") |
|
|
|
|
|
def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_size, sampling_rate): |
|
metric = evaluate.load("wer") |
|
asr_pipeline = pipeline(model=asr_model_name_or_path, device=device) |
|
|
|
return_language = None |
|
if isinstance(asr_pipeline.model, WhisperForConditionalGeneration): |
|
return_language = True |
|
|
|
transcriptions = asr_pipeline( |
|
[{"raw": audio, "sampling_rate": sampling_rate} for audio in audios], |
|
batch_size=int(per_device_eval_batch_size), |
|
return_language=return_language, |
|
) |
|
|
|
if isinstance(asr_pipeline.tokenizer, (WhisperTokenizer, WhisperTokenizerFast)): |
|
tokenizer = asr_pipeline.tokenizer |
|
else: |
|
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-large-v3") |
|
|
|
english_normalizer = tokenizer.normalize |
|
basic_normalizer = tokenizer.basic_normalize |
|
|
|
normalized_predictions = [] |
|
normalized_references = [] |
|
|
|
for pred, ref in zip(transcriptions, prompts): |
|
normalizer = english_normalizer if return_language and pred["chunks"][0]["language"] == "english" else basic_normalizer |
|
norm_ref = normalizer(ref) |
|
if len(norm_ref) > 0: |
|
norm_pred = normalizer(pred["text"]) |
|
normalized_predictions.append(norm_pred) |
|
normalized_references.append(norm_ref) |
|
|
|
word_error = 100 * metric.compute(predictions=normalized_predictions, references=normalized_references) |
|
|
|
return word_error, [t["text"] for t in transcriptions] |
|
|