File size: 7,337 Bytes
43380b2
af4ff35
aff6746
af4ff35
89fac21
43380b2
af4ff35
45347ae
 
 
aff6746
43380b2
aff6746
45347ae
 
 
aff6746
43380b2
905ed31
45347ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7115c2e
45347ae
 
 
01e4e5c
 
45347ae
dd10290
01e4e5c
dd10290
 
 
45347ae
d41eee2
dd10290
 
d41eee2
 
dd10290
45347ae
 
d41eee2
01e4e5c
45347ae
01e4e5c
d41eee2
 
 
 
01e4e5c
8c82d13
45347ae
7115c2e
dd10290
7115c2e
45347ae
7115c2e
 
 
 
 
 
 
dd10290
7115c2e
45347ae
7115c2e
 
 
 
 
 
 
01e4e5c
 
45347ae
01e4e5c
45347ae
 
 
 
 
 
01e4e5c
45347ae
01e4e5c
45347ae
01e4e5c
45347ae
01e4e5c
 
 
ea825ae
 
613c9f7
 
 
 
 
ea825ae
 
 
 
 
 
 
 
 
45347ae
ea825ae
 
 
 
 
dc5a821
 
aff6746
dc5a821
45347ae
aff6746
dc5a821
 
 
 
 
 
45347ae
 
 
 
 
 
7115c2e
 
 
 
45347ae
 
 
 
 
7115c2e
 
 
 
 
aff6746
 
45347ae
0115245
45347ae
0115245
45347ae
 
 
 
0115245
45347ae
 
 
 
aff6746
 
af4ff35
aff6746
0115245
aff6746
 
b4734b3
aff6746
 
45347ae
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import spaces
import jiwer
import numpy as np
import gradio as gr

@spaces.GPU()
def calculate_wer(reference, hypothesis):
    """Return the corpus-level word error rate between two sequences of lines.

    Each sequence is flattened into one space-separated string before scoring,
    so the two inputs do not need to contain the same number of lines.
    """
    # NOTE(review): jiwer scoring runs on the CPU; the ZeroGPU decorator is
    # presumably kept for the Space's hardware setup — confirm it is needed.
    flat_reference, flat_hypothesis = (" ".join(lines) for lines in (reference, hypothesis))
    return jiwer.wer(flat_reference, flat_hypothesis)

@spaces.GPU()
def calculate_cer(reference, hypothesis):
    """Return the corpus-level character error rate between two sequences of lines.

    Lines are joined with single spaces before scoring, so the two inputs do
    not need to contain the same number of lines.
    """
    # NOTE(review): jiwer scoring runs on the CPU; the ZeroGPU decorator is
    # presumably kept for the Space's hardware setup — confirm it is needed.
    flat_reference = " ".join(reference)
    flat_hypothesis = " ".join(hypothesis)
    score = jiwer.cer(flat_reference, flat_hypothesis)
    return score

@spaces.GPU()
def calculate_sentence_metrics(reference, hypothesis):
    """Score each reference line against the hypothesis line at the same position.

    Lines are stripped and paired by index. When the inputs differ in length,
    the surplus lines of the longer one are not scored here.

    Returns a dict with the per-line WER and CER lists plus their mean and
    standard deviation (numpy's default ddof=0). The summary statistics are
    0.0 when no line pairs were scored.
    """
    stripped_ref = [line.strip() for line in reference]
    stripped_hyp = [line.strip() for line in hypothesis]

    sentence_wers, sentence_cers = [], []
    for ref_line, hyp_line in zip(stripped_ref, stripped_hyp):
        sentence_wers.append(jiwer.wer(ref_line, hyp_line))
        sentence_cers.append(jiwer.cer(ref_line, hyp_line))

    def summarize(scores):
        # Guard empty input: report 0.0 instead of asking numpy for stats of nothing.
        if not scores:
            return 0.0, 0.0
        return np.mean(scores), np.std(scores)

    average_wer, std_dev_wer = summarize(sentence_wers)
    average_cer, std_dev_cer = summarize(sentence_cers)

    return {
        "sentence_wers": sentence_wers,
        "sentence_cers": sentence_cers,
        "average_wer": average_wer,
        "average_cer": average_cer,
        "std_dev_wer": std_dev_wer,
        "std_dev_cer": std_dev_cer,
    }

def _first_divergence(ref_words, hyp_words):
    """Return the index of the first differing word.

    When no paired word differs (one list is a prefix of the other), return the
    shorter length, i.e. the position of the first missing/extra word.
    """
    for position, (ref_word, hyp_word) in enumerate(zip(ref_words, hyp_words)):
        if ref_word != hyp_word:
            return position
    return min(len(ref_words), len(hyp_words))


def _highlight_word(words, position):
    """Join words[:position] and append words[position] in bold if it exists; "" for no words."""
    if not words:
        return ""
    shown = words[:position]
    if position < len(words):
        shown = shown + ["**" + words[position] + "**"]
    return " ".join(shown)


def identify_misaligned_sentences(reference, hypothesis):
    """List the lines where reference and hypothesis differ.

    Lines are stripped and compared by position. For each differing pair the
    entry records the 1-based line ``index``, both texts, the 0-based word
    position where they first diverge (``misalignment_start``), and a context
    string per side showing the words up to and including the divergent one
    (bolded). Lines present in only one input are reported with the
    placeholder "No corresponding sentence" on the other side.
    """
    missing = "No corresponding sentence"
    reference_sentences = [line.strip() for line in reference]
    hypothesis_sentences = [line.strip() for line in hypothesis]

    misaligned = []

    for i, (ref, hyp) in enumerate(zip(reference_sentences, hypothesis_sentences)):
        if ref == hyp:
            continue
        ref_words = ref.split()
        hyp_words = hyp.split()
        # Previously a prefix mismatch (e.g. "a b c" vs "a b") reported position 0
        # and highlighted a word that actually matched.
        start = _first_divergence(ref_words, hyp_words)
        misaligned.append({
            "index": i + 1,
            "reference": ref,
            "hypothesis": hyp,
            "misalignment_start": start,
            "context_ref": _highlight_word(ref_words, start),
            "context_hyp": _highlight_word(hyp_words, start),
        })

    # Lines present in only one input have no counterpart to diff against.
    # At most one of these ranges is non-empty.
    n_ref = len(reference_sentences)
    n_hyp = len(hypothesis_sentences)
    for i in range(n_hyp, n_ref):
        misaligned.append({
            "index": i + 1,
            "reference": reference_sentences[i],
            "hypothesis": missing,
            "misalignment_start": 0,
            "context_ref": reference_sentences[i],
            "context_hyp": missing,
        })
    for i in range(n_ref, n_hyp):
        misaligned.append({
            "index": i + 1,
            "reference": missing,
            "hypothesis": hypothesis_sentences[i],
            "misalignment_start": 0,
            "context_ref": missing,
            "context_hyp": hypothesis_sentences[i],
        })

    return misaligned

def format_sentence_metrics(sentence_wers, sentence_cers, average_wer, average_cer, std_dev_wer, std_dev_cer):
    """Render sentence-level WER/CER statistics as a Markdown string.

    Emits the four summary statistics, then one bullet per sentence for WER
    and for CER (1-based numbering). All numbers are shown with two decimals.
    """
    parts = [
        "### Sentence-level Metrics\n\n",
        f"**Average WER**: {average_wer:.2f}\n\n",
        f"**Standard Deviation WER**: {std_dev_wer:.2f}\n\n",
        f"**Average CER**: {average_cer:.2f}\n\n",
        f"**Standard Deviation CER**: {std_dev_cer:.2f}\n\n",
        "---\n**WER by Sentence**\n",
    ]
    parts.extend(f"- Sentence {n}: {score:.2f}\n" for n, score in enumerate(sentence_wers, start=1))
    parts.append("\n**CER by Sentence**\n")
    parts.extend(f"- Sentence {n}: {score:.2f}\n" for n, score in enumerate(sentence_cers, start=1))
    return "".join(parts)

def _upload_path(uploaded):
    """Return a filesystem path for an upload: either a path string or an object exposing ``.name``."""
    return getattr(uploaded, "name", uploaded)


def _read_lines(uploaded):
    """Read an uploaded text file and return its lines without line terminators.

    'utf-8-sig' drops a leading byte-order mark if one is present, so it does
    not become part of the first word and skew the comparison.
    """
    with open(_upload_path(uploaded), "r", encoding="utf-8-sig") as f:
        return f.read().splitlines()


def process_files(reference_file, hypothesis_file):
    """Read both uploads and compute every metric shown in the UI.

    Returns a dict with "Overall WER", "Overall CER", the keys produced by
    calculate_sentence_metrics, and "Misaligned Sentences". On any failure it
    returns {"error": <message>} instead of raising, so the UI can display it.
    """
    # Clicking "Compute" before uploading both files passes None; give a clear
    # message instead of "'NoneType' object has no attribute 'name'".
    if reference_file is None or hypothesis_file is None:
        return {"error": "Please upload both a reference file and a hypothesis file."}

    try:
        reference_text = _read_lines(reference_file)
        hypothesis_text = _read_lines(hypothesis_file)

        overall_wer = calculate_wer(reference_text, hypothesis_text)
        overall_cer = calculate_cer(reference_text, hypothesis_text)
        sentence_metrics = calculate_sentence_metrics(reference_text, hypothesis_text)
        misaligned = identify_misaligned_sentences(reference_text, hypothesis_text)

        return {
            "Overall WER": overall_wer,
            "Overall CER": overall_cer,
            **sentence_metrics,
            "Misaligned Sentences": misaligned,
        }
    except Exception as e:
        # UI boundary: surface any failure (e.g. an unreadable file or an
        # exception raised by jiwer) to the user rather than crashing the handler.
        return {"error": str(e)}

def process_and_display(ref_file, hyp_file):
    """Gradio click handler: compute metrics and return the three UI outputs.

    Returns a tuple of (JSON summary dict, sentence-metrics Markdown,
    misaligned-sentences Markdown). On error, the first element is
    {"error": ...} and both Markdown outputs are empty strings.
    """
    result = process_files(ref_file, hyp_file)

    if "error" in result:
        return {"error": result["error"]}, "", ""

    summary = {key: result[key] for key in ("Overall WER", "Overall CER")}

    # The result keys match format_sentence_metrics' parameter names exactly.
    metric_keys = ("sentence_wers", "sentence_cers", "average_wer", "average_cer", "std_dev_wer", "std_dev_cer")
    metrics_md = format_sentence_metrics(**{key: result[key] for key in metric_keys})

    entries = result["Misaligned Sentences"]
    if entries:
        body = "".join(
            f"**Sentence {entry['index']}**\n"
            f"- Reference: {entry['context_ref']}\n"
            f"- Hypothesis: {entry['context_hyp']}\n"
            f"- Misalignment starts at position: {entry['misalignment_start']}\n\n"
            for entry in entries
        )
    else:
        body = "* No misaligned sentences found."
    misaligned_md = "### Misaligned Sentences\n\n" + body

    return summary, metrics_md, misaligned_md

def main():
    """Build the ASR metrics Gradio UI and launch it.

    Two file inputs feed ``process_and_display``. Its three return values are
    routed, in order, to the JSON summary and the two Markdown panels.
    """
    with gr.Blocks() as demo:
        # Title emoji restored: the source had mojibake ("πŸ“Š"), i.e. the
        # UTF-8 bytes of the chart emoji decoded with the wrong codec.
        gr.Markdown("# 📊 ASR Metrics Analysis Tool")

        with gr.Row():
            with gr.Column():
                gr.Markdown("### Upload your reference and hypothesis files")
                reference_file = gr.File(label="Reference File (.txt)")
                hypothesis_file = gr.File(label="Hypothesis File (.txt)")
                compute_button = gr.Button("Compute Metrics", variant="primary")

            with gr.Column():
                results_output = gr.JSON(label="Results Summary")
                metrics_output = gr.Markdown(label="Sentence Metrics")
                misaligned_output = gr.Markdown(label="Misaligned Sentences")

        compute_button.click(
            fn=process_and_display,
            inputs=[reference_file, hypothesis_file],
            outputs=[results_output, metrics_output, misaligned_output],
        )

    # Server-side rendering is explicitly switched off for this app.
    demo.launch(ssr_mode=False)

# Launch the UI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()