# Page-header residue from the web file viewer, kept as comments so the module parses:
#   author: sanchit-gandhi | last commit: "repeated n-grams" (eed20cf) | size: 5.47 kB
import os
import numpy as np
import unicodedata
import diff_match_patch as dmp_module
from enum import Enum
import gradio as gr
from datasets import load_dataset
import pandas as pd
from jiwer import process_words, wer_default
from nltk import ngrams
class Action(Enum):
    """Integer opcodes carried by the ``(action, text)`` pairs of a diff.

    ``style_text`` compares the first element of each diff pair against
    these values to decide how the accompanying text is rendered.
    """

    INSERTION = 1  # rendered with a green background
    DELETION = -1  # rendered struck through on a red background
    EQUAL = 0  # rendered unchanged
def compare_string(text1: str, text2: str) -> list:
    """Diff two strings after Unicode NFKC normalisation.

    Args:
        text1: Baseline string.
        text2: String compared against ``text1``.

    Returns:
        The diff list produced by ``diff_match_patch.diff_main`` on the
        normalised strings, after ``diff_cleanupSemantic`` has been run on it.
    """
    normalized_a, normalized_b = (
        unicodedata.normalize("NFKC", raw) for raw in (text1, text2)
    )
    differ = dmp_module.diff_match_patch()
    edits = differ.diff_main(normalized_a, normalized_b)
    # Called for its effect on `edits`; the return value is not used.
    differ.diff_cleanupSemantic(edits)
    return edits
def style_text(diff):
    """Render a diff as an HTML-annotated string for display in Markdown.

    Args:
        diff: Iterable of ``(action, text)`` pairs, where ``action`` equals
            one of the ``Action`` member values.

    Returns:
        The concatenated text. Insertions are wrapped in a light-green
        ``<span>``, deletions in a struck-through ``#FFCCCB`` ``<span>``,
        and equal runs are left as-is. Every ``](`` and ``~`` in the result
        is backslash-escaped.

    Raises:
        NotImplementedError: If an action matches no ``Action`` value.
    """
    pieces = []
    for action, text in diff:
        if action == Action.INSERTION.value:
            pieces.append(f"<span style='background-color:Lightgreen'>{text}</span>")
        elif action == Action.DELETION.value:
            pieces.append(
                f"<span style='background-color:#FFCCCB'><s>{text}</s></span>"
            )
        elif action == Action.EQUAL.value:
            pieces.append(f"{text}")
        else:
            # Subclass of Exception, so existing `except Exception` handlers still match.
            raise NotImplementedError("Not Implemented")

    # Join once rather than growing a string with `+=` on every diff entry.
    full_text = "".join(pieces)
    # The backslashes are written explicitly ("\\") because "\(" and "\~" are
    # invalid escape sequences; the resulting characters are unchanged.
    # NOTE(review): presumably this stops Markdown treating "](" as link
    # syntax and "~" as strikethrough - confirm against the renderer.
    return full_text.replace("](", "]\\(").replace("~", "\\~")
# Validation split that supplies the audio shown in the demo.
dataset = load_dataset(
    "distil-whisper/tedlium-long-form", split="validation", num_proc=os.cpu_count()
)

# Target transcriptions and large-v2 predictions, one list entry per CSV row.
csv_v2 = pd.read_csv("assets/large-v2.csv")
norm_target = csv_v2["Norm Target"].tolist()
norm_pred_v2 = csv_v2["Norm Pred"].tolist()

# large-32-2 predictions. `csv_v2` is deliberately rebound here so that every
# module-level name ends up holding the same data as before.
csv_v2 = pd.read_csv("assets/large-32-2.csv")
norm_pred_32_2 = csv_v2["Norm Pred"].tolist()

# Integer type and scale factor used when converting audio arrays for playback.
target_dtype = np.int16
max_range = np.iinfo(target_dtype).max
def get_visualisation(idx, model="large-v2", round_dp=2, ngram_degree=5):
    """Gather the audio, error metrics and styled text diff for one sample.

    Args:
        idx: 1-based sample number; it is converted to a 0-based index here.
        model: Which predictions to compare with the target transcription,
            either ``"large-v2"`` or ``"large-32-2"``.
        round_dp: Number of decimal places kept in the reported percentages.
        ngram_degree: Size ``n`` of the word n-grams used for the repetition
            count.

    Returns:
        A 5-tuple ``((sampling_rate, array), wer_percentage, ier_percentage,
        repeated_ngrams, full_text)``. ``array`` is the audio scaled by
        ``max_range`` and cast to ``target_dtype``. ``ier_percentage`` is
        insertions divided by the number of reference words, as a percentage.
        ``repeated_ngrams`` counts n-gram occurrences in the prediction beyond
        the first of each distinct n-gram. ``full_text`` is the styled diff
        of target against prediction.

    Raises:
        ValueError: If ``model`` is not one of the two supported names.
    """
    # Coerce to int so a float-valued number from the UI can still index.
    idx = int(idx) - 1

    audio = dataset[idx]["audio"]
    # Use the module-level dtype that `max_range` was derived from.
    array = (audio["array"] * max_range).astype(target_dtype)
    sampling_rate = audio["sampling_rate"]

    text1 = norm_target[idx]
    if model == "large-v2":
        text2 = norm_pred_v2[idx]
    elif model == "large-32-2":
        text2 = norm_pred_32_2[idx]
    else:
        raise ValueError(f"Got unknown model {model}, should be one of `'large-v2'` or `'large-32-2'`.")

    wer_output = process_words(text1, text2, wer_default, wer_default)
    wer_percentage = round(100 * wer_output.wer, round_dp)
    ier_percentage = round(
        100 * wer_output.insertions / len(wer_output.references[0]), round_dp
    )

    # n-grams are tuples, so a set gives the distinct count in linear time
    # instead of a list-membership scan per n-gram.
    all_ngrams = list(ngrams(text2.split(), ngram_degree))
    repeated_ngrams = len(all_ngrams) - len(set(all_ngrams))

    diff = compare_string(text1, text2)
    full_text = style_text(diff)

    return (sampling_rate, array), wer_percentage, ier_percentage, repeated_ngrams, full_text
def get_side_by_side_visualisation(idx):
    """Build the outputs comparing both models on one sample.

    Args:
        idx: Sample number, forwarded unchanged to ``get_visualisation``.

    Returns:
        A 4-tuple ``(audio, rows, diff_v2, diff_32_2)``. ``audio`` comes from
        the large-v2 call. ``rows`` holds one list per model: the model name
        followed by that model's metrics. The last two items are the styled
        text diffs for large-v2 and large-32-2.
    """
    # Each result is (audio, *metrics, styled_diff); unpack it in one step.
    audio, *metrics_v2, diff_v2 = get_visualisation(idx, model="large-v2")
    _, *metrics_32_2, diff_32_2 = get_visualisation(idx, model="large-32-2")

    # Prefix every metrics row with its model name.
    rows = [
        ["large-v2", *metrics_v2],
        ["large-32-2", *metrics_32_2],
    ]
    return audio, rows, diff_v2, diff_32_2
if __name__ == "__main__":
    # Build the demo UI: title, description, sample slider, audio player,
    # metrics table and the two text diffs side by side.
    with gr.Blocks() as demo:
        # The HTML literal is kept flush-left so its contents stay unchanged.
        gr.HTML(
            """
<div style="text-align: center; max-width: 700px; margin: 0 auto;">
<div
style="
display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;
"
>
<h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
Whisper Transcription Analysis
</h1>
</div>
</div>
"""
        )
        # Adjacent literals are concatenated, so each one ends with the space
        # that separates it from the next sentence.
        gr.Markdown(
            "Analyse the transcriptions generated by the Whisper large-v2 and large-32-2 models on the TEDLIUM dev set. "
            "The transcriptions for both models are shown at the bottom of the demo. The text diff for each is computed "
            "relative to the target transcriptions. Insertions are displayed in <span style='background-color:Lightgreen'>green</span>, and "
            "deletions in <span style='background-color:#FFCCCB'><s>red</s></span>."
        )

        # The slider is 1-based; get_visualisation converts to a 0-based index.
        slider = gr.Slider(
            minimum=1, maximum=len(norm_target), step=1, label="Dataset sample"
        )
        btn = gr.Button("Analyse")
        audio_out = gr.Audio(label="Audio input")

        with gr.Column():
            table = gr.Dataframe(
                headers=[
                    "Model",
                    "Word Error Rate (WER)",
                    "Insertion Error Rate (IER)",
                    "Repeated 5-grams",
                ],
                height=1000,
            )

        # Column headings for the two diffs, then the diffs themselves.
        with gr.Row():
            gr.Markdown("**large-v2 text diff**")
            gr.Markdown("**large-32-2 text diff**")
        with gr.Row():
            text_out_v2 = gr.Markdown(label="Text difference")
            text_out_32_2 = gr.Markdown(label="Text difference")

        # Output order matches the tuple returned by get_side_by_side_visualisation.
        btn.click(
            fn=get_side_by_side_visualisation,
            inputs=slider,
            outputs=[audio_out, table, text_out_v2, text_out_32_2],
        )

    demo.launch()