Hecheng0625's picture
Upload 167 files
8c92a11 verified
raw
history blame
2.97 kB
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torchmetrics import WordErrorRate
def extract_wer(
model,
**kwargs,
):
"""Compute Word Error Rate (WER) between the predicted and the ground truth audio.
content_gt: the ground truth content.
audio_ref: path to the ground truth audio.
audio_deg: path to the predicted audio.
mode: "gt_content" computes the WER between the predicted content obtained from the whisper model and the ground truth content.
both content_gt and audio_deg are needed.
"gt_audio" computes the WER between the extracted ground truth and predicted contents obtained from the whisper model.
both audio_ref and audio_deg are needed.
"""
kwargs = kwargs["kwargs"]
mode = kwargs["intelligibility_mode"]
language = kwargs["language"]
wer = WordErrorRate()
if torch.cuda.is_available():
device = torch.device("cuda")
wer = wer.to(device)
# Get ground truth content
if mode == "gt_content":
content_gt = kwargs["content_gt"]
audio_deg = kwargs["audio_deg"]
if language == "chinese":
prompt = "ไปฅไธ‹ๆ˜ฏๆ™ฎ้€š่ฏ็š„ๅฅๅญ"
result_deg = model.transcribe(
audio_deg, language="zh", verbose=True, initial_prompt=prompt
)
else:
result_deg = model.transcribe(audio_deg, verbose=True)
elif mode == "gt_audio":
audio_ref = kwargs["audio_ref"]
audio_deg = kwargs["audio_deg"]
if language == "chinese":
prompt = "ไปฅไธ‹ๆ˜ฏๆ™ฎ้€š่ฏ็š„ๅฅๅญ"
result_ref = model.transcribe(
audio_ref, language="zh", verbose=True, initial_prompt=prompt
)
result_deg = model.transcribe(
audio_deg, language="zh", verbose=True, initial_prompt=prompt
)
else:
result_ref = model.transcribe(audio_deg, verbose=True)
result_deg = model.transcribe(audio_deg, verbose=True)
content_gt = result_ref["text"]
content_gt = content_gt.replace(" ", "")
content_gt = content_gt.replace(".", "")
content_gt = content_gt.replace("'", "")
content_gt = content_gt.replace("-", "")
content_gt = content_gt.replace(",", "")
content_gt = content_gt.replace("!", "")
content_gt = content_gt.lower()
# Get predicted truth content
content_pred = result_deg["text"]
content_pred = content_pred.replace(" ", "")
content_pred = content_pred.replace(".", "")
content_pred = content_pred.replace("'", "")
content_pred = content_pred.replace("-", "")
content_pred = content_pred.replace(",", "")
content_pred = content_pred.replace("!", "")
content_pred = content_pred.lower()
return wer(content_pred, content_gt).detach().cpu().numpy().tolist()