File size: 3,286 Bytes
b725c5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import whisper

from torchmetrics import WordErrorRate


# Keyword arguments passed to ``model.transcribe`` for each supported language.
_WHISPER_LANGUAGE_OPTIONS = {
    "chinese": {"language": "zh", "initial_prompt": "以下是普通话的句子"},
    "english": {"language": "en"},
}

_SUPPORTED_MODES = ("gt_content", "gt_audio")

# Characters deleted when ``remove_punctuation`` is enabled.
_PUNCTUATION_TABLE = str.maketrans("", "", ".'-,!")


def _normalize_text(text, remove_space, remove_punctuation):
    """Apply the optional space / punctuation stripping used before scoring."""
    if remove_space:
        text = text.replace(" ", "")
    if remove_punctuation:
        # Lower-casing happens only together with punctuation removal.
        text = text.translate(_PUNCTUATION_TABLE).lower()
    return text


def extract_wer(
    content_gt=None,
    audio_ref=None,
    audio_deg=None,
    fs=None,
    language="chinese",
    remove_space=True,
    remove_punctuation=True,
    mode="gt_audio",
):
    """Compute the Word Error Rate (WER) of a predicted audio.

    The predicted audio is transcribed with the whisper "large" model (moved
    to CUDA) and scored with torchmetrics' ``WordErrorRate`` against a
    reference text.

    content_gt: the ground truth content (required when mode is "gt_content";
          overwritten by the reference transcription when mode is "gt_audio").
    audio_ref: path to the ground truth audio (required when mode is "gt_audio").
    audio_deg: path to the predicted audio (always required).
    fs: unused; kept so existing callers keep working.
    language: "chinese" or "english"; selects the whisper language code and,
          for Chinese, an initial prompt.
    remove_space: if True, delete spaces from both texts before scoring.
    remove_punctuation: if True, delete the characters . ' - , ! from both
          texts and lower-case them before scoring.
    mode: "gt_content" computes the WER between the predicted content obtained
          from the whisper model and the ground truth content.
          both content_gt and audio_deg are needed.
          "gt_audio" computes the WER between the extracted ground truth and
          predicted contents obtained from the whisper model.
          both audio_ref and audio_deg are needed.

    Returns the metric value converted with ``.numpy().tolist()``.

    Raises ValueError if mode or language is unsupported, or if an input
    required by the selected mode is missing.
    """
    # Validate everything up front so bad arguments fail fast, before the
    # (expensive) model load.
    if mode not in _SUPPORTED_MODES:
        raise ValueError(
            f"Unsupported mode {mode!r}; expected one of {_SUPPORTED_MODES}."
        )
    if language not in _WHISPER_LANGUAGE_OPTIONS:
        raise ValueError(
            f"Unsupported language {language!r}; expected one of "
            f"{tuple(_WHISPER_LANGUAGE_OPTIONS)}."
        )
    if audio_deg is None:
        raise ValueError("audio_deg is required.")
    if mode == "gt_content" and content_gt is None:
        raise ValueError('content_gt is required when mode is "gt_content".')
    if mode == "gt_audio" and audio_ref is None:
        raise ValueError('audio_ref is required when mode is "gt_audio".')

    model = whisper.load_model("large").cuda()
    transcribe_options = _WHISPER_LANGUAGE_OPTIONS[language]

    # Get ground truth content
    if mode == "gt_audio":
        # The reference text comes from the reference audio, not the
        # predicted one.
        result_ref = model.transcribe(audio_ref, verbose=True, **transcribe_options)
        content_gt = result_ref["text"]

    # Get predicted content
    result_deg = model.transcribe(audio_deg, verbose=True, **transcribe_options)
    content_pred = result_deg["text"]

    # Both sides get the same normalization so they are compared like for like.
    # NOTE(review): with remove_space=True each text has no spaces left; confirm
    # this is the intended granularity for a whitespace-tokenized WER metric.
    content_gt = _normalize_text(content_gt, remove_space, remove_punctuation)
    content_pred = _normalize_text(content_pred, remove_space, remove_punctuation)

    wer = WordErrorRate()

    return wer(content_pred, content_gt).numpy().tolist()