# asr_metrics / app.py
# Last commit: 613c9f7 by akki2825 (verified) — "change readlines"
import spaces
import jiwer
import numpy as np
import gradio as gr
@spaces.GPU()
def calculate_wer(reference, hypothesis):
    """Compute one word error rate over two sequences of transcript lines.

    Each side is flattened into a single space-separated string before
    scoring, so jiwer performs one alignment over the whole text rather than
    averaging per-line scores.

    Args:
        reference: Iterable of reference transcript lines.
        hypothesis: Iterable of hypothesis transcript lines.

    Returns:
        The value returned by ``jiwer.wer`` for the two flattened strings.
    """
    flat_reference, flat_hypothesis = (
        " ".join(lines) for lines in (reference, hypothesis)
    )
    return jiwer.wer(flat_reference, flat_hypothesis)
@spaces.GPU()
def calculate_cer(reference, hypothesis):
    """Compute one character error rate over two sequences of transcript lines.

    Both inputs are joined with single spaces into one string each, and the
    pair is scored with a single ``jiwer.cer`` call.

    Args:
        reference: Iterable of reference transcript lines.
        hypothesis: Iterable of hypothesis transcript lines.

    Returns:
        The value returned by ``jiwer.cer`` for the two joined strings.
    """
    separator = " "
    return jiwer.cer(separator.join(reference), separator.join(hypothesis))
@spaces.GPU()
def calculate_sentence_metrics(reference, hypothesis):
    """Score each positionally-paired line and summarise the per-line scores.

    Lines are stripped of surrounding whitespace and paired by position;
    lines beyond the length of the shorter input are not scored here.

    Args:
        reference: Sequence of reference lines.
        hypothesis: Sequence of hypothesis lines.

    Returns:
        dict with keys:
            "sentence_wers", "sentence_cers": per-pair jiwer scores in line order.
            "average_wer", "average_cer": ``np.mean`` of those lists (0.0 if empty).
            "std_dev_wer", "std_dev_cer": ``np.std`` (population, ddof=0) of
            those lists (0.0 if empty).
    """
    ref_lines = [line.strip() for line in reference]
    hyp_lines = [line.strip() for line in hypothesis]

    # NOTE(review): depending on the installed jiwer version, an empty
    # reference line may be rejected with ValueError — confirm how blank
    # lines in the input files should be scored.
    sentence_wers, sentence_cers = [], []
    for ref_line, hyp_line in zip(ref_lines, hyp_lines):
        sentence_wers.append(jiwer.wer(ref_line, hyp_line))
        sentence_cers.append(jiwer.cer(ref_line, hyp_line))

    def summarise(scores):
        # (mean, std) of the scores; an empty list summarises to zeros.
        if not scores:
            return 0.0, 0.0
        return np.mean(scores), np.std(scores)

    average_wer, std_dev_wer = summarise(sentence_wers)
    average_cer, std_dev_cer = summarise(sentence_cers)

    return {
        "sentence_wers": sentence_wers,
        "sentence_cers": sentence_cers,
        "average_wer": average_wer,
        "average_cer": average_cer,
        "std_dev_wer": std_dev_wer,
        "std_dev_cer": std_dev_cer,
    }
def _highlight_word(words, position):
    """Join ``words[:position + 1]`` with spaces, wrapping ``words[position]`` in ``**``.

    When ``position`` is at or past the end of ``words`` (this line ended
    before the point where the other line diverges), all words are returned
    unhighlighted; an empty word list yields "".
    """
    if position >= len(words):
        return " ".join(words)
    return " ".join(words[:position] + ["**" + words[position] + "**"])


def identify_misaligned_sentences(reference, hypothesis):
    """List line pairs whose text differs and locate the first differing word.

    Lines are stripped of surrounding whitespace and compared by position.
    Lines present in only one input are reported with the placeholder
    "No corresponding sentence" on the missing side.

    Args:
        reference: Sequence of reference lines.
        hypothesis: Sequence of hypothesis lines.

    Returns:
        A list of dicts, one per mismatching line, with keys:
            "index": 1-based line number.
            "reference", "hypothesis": the stripped lines (or the placeholder).
            "misalignment_start": 0-based index of the first differing word;
                when one line's words are a prefix of the other's, this is
                the shorter line's word count.
            "context_ref", "context_hyp": the words up to and including the
                first differing word, with that word wrapped in ``**``.
    """
    reference_sentences = [line.strip() for line in reference]
    hypothesis_sentences = [line.strip() for line in hypothesis]
    misaligned = []

    for i, (ref, hyp) in enumerate(zip(reference_sentences, hypothesis_sentences)):
        if ref == hyp:
            continue
        ref_words = ref.split()
        hyp_words = hyp.split()
        # If no word differs within the shared prefix, the lines diverge
        # where the shorter one runs out (not at word 0).
        misalignment_start = next(
            (j for j, (rw, hw) in enumerate(zip(ref_words, hyp_words)) if rw != hw),
            min(len(ref_words), len(hyp_words)),
        )
        misaligned.append({
            "index": i + 1,
            "reference": ref,
            "hypothesis": hyp,
            "misalignment_start": misalignment_start,
            "context_ref": _highlight_word(ref_words, misalignment_start),
            "context_hyp": _highlight_word(hyp_words, misalignment_start),
        })

    # Lines with no counterpart in the other file. At most one of these two
    # loops does any work.
    shared = min(len(reference_sentences), len(hypothesis_sentences))
    for i in range(shared, len(reference_sentences)):
        misaligned.append({
            "index": i + 1,
            "reference": reference_sentences[i],
            "hypothesis": "No corresponding sentence",
            "misalignment_start": 0,
            "context_ref": reference_sentences[i],
            "context_hyp": "No corresponding sentence",
        })
    for i in range(shared, len(hypothesis_sentences)):
        misaligned.append({
            "index": i + 1,
            "reference": "No corresponding sentence",
            "hypothesis": hypothesis_sentences[i],
            "misalignment_start": 0,
            "context_ref": "No corresponding sentence",
            "context_hyp": hypothesis_sentences[i],
        })
    return misaligned
def format_sentence_metrics(
    sentence_wers, sentence_cers, average_wer, average_cer, std_dev_wer, std_dev_cer
):
    """Render sentence-level WER/CER statistics as a Markdown string.

    The output has a heading, the four summary statistics, a horizontal
    rule, then one bullet per sentence for WER followed by one bullet per
    sentence for CER. All numbers are formatted to two decimal places and
    sentences are numbered from 1.

    Args:
        sentence_wers: Per-sentence WER values.
        sentence_cers: Per-sentence CER values.
        average_wer, average_cer: Mean WER / CER.
        std_dev_wer, std_dev_cer: Standard deviation of WER / CER.

    Returns:
        The Markdown text.
    """
    parts = [
        "### Sentence-level Metrics\n\n",
        f"**Average WER**: {average_wer:.2f}\n\n",
        f"**Standard Deviation WER**: {std_dev_wer:.2f}\n\n",
        f"**Average CER**: {average_cer:.2f}\n\n",
        f"**Standard Deviation CER**: {std_dev_cer:.2f}\n\n",
        "---\n**WER by Sentence**\n",
    ]
    parts.extend(
        f"- Sentence {number}: {score:.2f}\n"
        for number, score in enumerate(sentence_wers, start=1)
    )
    parts.append("\n**CER by Sentence**\n")
    parts.extend(
        f"- Sentence {number}: {score:.2f}\n"
        for number, score in enumerate(sentence_cers, start=1)
    )
    return "".join(parts)
def _read_lines(uploaded):
    """Return the lines of an uploaded text file, without line terminators.

    ``uploaded`` may be a path string, or any object exposing the path as
    ``.name``. The file is decoded as UTF-8. A leading byte-order mark is
    dropped (``utf-8-sig``) so that it does not get glued to the first word
    and counted as a transcription error.
    """
    path = uploaded if isinstance(uploaded, str) else uploaded.name
    with open(path, "r", encoding="utf-8-sig") as f:
        return f.read().splitlines()


def process_files(reference_file, hypothesis_file):
    """Read both uploads and compute every metric shown in the UI.

    Args:
        reference_file: Uploaded reference transcript (path string or an
            object with a ``.name`` path), one sentence per line; may be None
            if nothing was uploaded.
        hypothesis_file: Uploaded hypothesis transcript, same format.

    Returns:
        On success, a dict with "Overall WER", "Overall CER", the keys
        produced by ``calculate_sentence_metrics`` and "Misaligned Sentences".
        On failure, ``{"error": <message>}``; this function does not raise.
    """
    if reference_file is None or hypothesis_file is None:
        return {"error": "Please upload both a reference file and a hypothesis file."}
    try:
        reference_text = _read_lines(reference_file)
        hypothesis_text = _read_lines(hypothesis_file)
        overall_wer = calculate_wer(reference_text, hypothesis_text)
        overall_cer = calculate_cer(reference_text, hypothesis_text)
        sentence_metrics = calculate_sentence_metrics(reference_text, hypothesis_text)
        misaligned = identify_misaligned_sentences(reference_text, hypothesis_text)
    except Exception as e:
        # UI boundary: report any failure (unreadable/undecodable file, the
        # metric library rejecting the input, ...) on the page instead of
        # crashing the event handler. Fall back to the exception type when
        # the message is empty.
        return {"error": str(e) or type(e).__name__}
    return {
        "Overall WER": overall_wer,
        "Overall CER": overall_cer,
        **sentence_metrics,
        "Misaligned Sentences": misaligned,
    }
def process_and_display(ref_file, hyp_file):
    """Run the analysis and shape the result for the three Gradio outputs.

    Args:
        ref_file: Uploaded reference file, passed through to ``process_files``.
        hyp_file: Uploaded hypothesis file, passed through to ``process_files``.

    Returns:
        A 3-tuple ``(summary, sentence_md, misaligned_md)``:
            summary: ``{"error": ...}`` on failure; otherwise a dict holding
                only "Overall WER" and "Overall CER".
            sentence_md: Markdown from ``format_sentence_metrics`` ("" on failure).
            misaligned_md: Markdown listing misaligned sentences ("" on failure).
    """
    result = process_files(ref_file, hyp_file)
    if "error" in result:
        return {"error": result["error"]}, "", ""

    summary = {key: result[key] for key in ("Overall WER", "Overall CER")}

    sentence_md = format_sentence_metrics(
        sentence_wers=result["sentence_wers"],
        sentence_cers=result["sentence_cers"],
        average_wer=result["average_wer"],
        average_cer=result["average_cer"],
        std_dev_wer=result["std_dev_wer"],
        std_dev_cer=result["std_dev_cer"],
    )

    entries = result["Misaligned Sentences"]
    sections = ["### Misaligned Sentences\n\n"]
    if not entries:
        sections.append("* No misaligned sentences found.")
    for entry in entries:
        sections.append(f"**Sentence {entry['index']}**\n")
        sections.append(f"- Reference: {entry['context_ref']}\n")
        sections.append(f"- Hypothesis: {entry['context_hyp']}\n")
        sections.append(f"- Misalignment starts at position: {entry['misalignment_start']}\n\n")

    return summary, sentence_md, "".join(sections)
def main():
    """Build the Gradio interface and launch it.

    Layout: a left column holds the reference/hypothesis uploads and a
    "Compute Metrics" button. A right column holds a JSON summary and two
    Markdown panels. Clicking the button runs ``process_and_display`` on the
    two uploads and fills the three output panels in order.
    """
    with gr.Blocks() as demo:
        # "\U0001F4CA" is the BAR CHART emoji. It is escaped so the source stays
        # ASCII; the previous literal had been mis-decoded into mojibake.
        gr.Markdown("# \U0001F4CA ASR Metrics Analysis Tool")
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Upload your reference and hypothesis files")
                reference_file = gr.File(label="Reference File (.txt)")
                hypothesis_file = gr.File(label="Hypothesis File (.txt)")
                compute_button = gr.Button("Compute Metrics", variant="primary")
            with gr.Column():
                results_output = gr.JSON(label="Results Summary")
                metrics_output = gr.Markdown(label="Sentence Metrics")
                misaligned_output = gr.Markdown(label="Misaligned Sentences")
        # The event listener is registered inside the Blocks context.
        compute_button.click(
            fn=process_and_display,
            inputs=[reference_file, hypothesis_file],
            outputs=[results_output, metrics_output, misaligned_output],
        )
    demo.launch(ssr_mode=False)
# Script entry point: build and launch the UI only when this file is run
# directly, not when it is imported.
if __name__ == "__main__":
    main()