# Source file header (scraped): author sanchit-gandhi, commit 8a02493 ("up"), 2.86 kB.
import html
import unicodedata
from enum import Enum

import diff_match_patch as dmp_module
import gradio as gr
import numpy as np
import pandas as pd
from datasets import load_dataset
from jiwer import process_words, wer_default
# Diff operation codes, in the integer form that ``style_text`` compares the
# first element of each diff pair against:
#   INSERTION =  1 -> text present only in the second string
#   DELETION  = -1 -> text present only in the first string
#   EQUAL     =  0 -> text common to both strings
Action = Enum(
    "Action",
    [
        ("INSERTION", 1),
        ("DELETION", -1),
        ("EQUAL", 0),
    ],
)
def compare_string(text1: str, text2: str) -> list:
    """Return a semantically cleaned character diff between two strings.

    Both inputs are NFKC-normalised first, so compatibility-equivalent
    characters are compared in a single canonical form.

    Args:
        text1: The first (reference) string.
        text2: The second string, compared against ``text1``.

    Returns:
        The list of ``(operation, text)`` pairs produced by
        ``diff_match_patch.diff_main``, after ``diff_cleanupSemantic`` has
        been applied to it.
    """
    left, right = (unicodedata.normalize("NFKC", t) for t in (text1, text2))
    differ = dmp_module.diff_match_patch()
    # NOTE(review): diff_main runs with the library's default Diff_Timeout;
    # for very long transcripts the diff may be less minimal -- confirm OK.
    edits = differ.diff_main(left, right)
    # The cleanup works on the list in place; its return value is not used.
    differ.diff_cleanupSemantic(edits)
    return edits
def style_text(diff):
    """Render a diff as Markdown/HTML that highlights insertions and deletions.

    Args:
        diff: Iterable of ``(action, text)`` pairs, as returned by
            ``compare_string``. ``action`` is compared against the
            ``Action`` enum values (1 = insertion, -1 = deletion, 0 = equal).

    Returns:
        A single string in which:
        - inserted text is wrapped in a light-green ``<span>``;
        - deleted text is wrapped in a light-red ``<span>`` and struck through;
        - unchanged text is emitted as-is.
        Segment text is HTML-escaped (``&``, ``<``, ``>``). Occurrences of
        ``](`` and ``~`` are backslash-escaped, presumably so the Markdown
        renderer does not read them as link syntax or strikethrough.

    Raises:
        ValueError: If a pair carries an action that is not an ``Action`` value.
    """
    parts = []
    for action, text in diff:
        # Escape &, < and > so transcript text can never be parsed as HTML tags.
        safe_text = html.escape(text, quote=False)
        if action == Action.INSERTION.value:
            parts.append(f"<span style='background-color:Lightgreen'>{safe_text}</span>")
        elif action == Action.DELETION.value:
            parts.append(f"<span style='background-color:#FFCCCB'><s>{safe_text}</s></span>")
        elif action == Action.EQUAL.value:
            parts.append(safe_text)
        else:
            raise ValueError(f"Not Implemented: unknown diff action {action!r}")
    # join instead of repeated += avoids quadratic string building.
    full_text = "".join(parts)
    # Raw strings: "\(" and "\~" are invalid escape sequences in ordinary
    # string literals (SyntaxWarning on Python 3.12+); the produced text is
    # unchanged (a literal backslash followed by the character).
    return full_text.replace("](", r"]\(").replace("~", r"\~")
# Long-form TED-LIUM validation split. Row ``i`` here is indexed together with
# row ``i`` of the CSV below -- NOTE(review): assumes both share the same order.
dataset = load_dataset("distil-whisper/tedlium-long-form", split="validation")

# Normalised reference ("Norm Target") and prediction ("Norm Pred") per sample;
# presumably Whisper large-v2 outputs, going by the file name -- confirm.
csv = pd.read_csv("assets/large-v2.csv")
norm_target = csv["Norm Target"].tolist()
norm_pred = csv["Norm Pred"].tolist()

# Audio samples are scaled by the maximum of this integer dtype before being
# converted for the audio widget.
target_dtype = np.int16
max_range = np.iinfo(target_dtype).max
def get_visualisation(idx):
    """Build every output shown in the demo for one dataset sample.

    Args:
        idx: Row index into both ``dataset`` and the CSV transcripts. Coerced
            to ``int`` because it comes from a ``gr.Slider`` and may arrive
            as a float, depending on the Gradio version.

    Returns:
        A 5-tuple ``((sampling_rate, int_audio), wer_percentage,
        num_insertions, rel_length, full_text)``:
        - ``rel_length`` is the word count of the prediction divided by the
          word count of the reference;
        - ``full_text`` is the highlighted Markdown/HTML diff.
    """
    idx = int(idx)
    audio = dataset[idx]["audio"]
    # Assumes float samples nominally in [-1, 1]. Clip before scaling so
    # slightly out-of-range samples saturate instead of wrapping around when
    # cast to the integer dtype.
    scaled = np.clip(audio["array"], -1.0, 1.0) * max_range
    array = scaled.astype(target_dtype)
    sampling_rate = audio["sampling_rate"]

    text1 = norm_target[idx]
    text2 = norm_pred[idx]

    # Reference first, hypothesis second; the same transform is used for both.
    wer_output = process_words(text1, text2, wer_default, wer_default)
    wer_percentage = 100 * wer_output.wer
    num_insertions = wer_output.insertions
    rel_length = len(text2.split()) / len(text1.split())

    diff = compare_string(text1, text2)
    full_text = style_text(diff)
    return (sampling_rate, array), wer_percentage, num_insertions, rel_length, full_text
if __name__ == "__main__":
    with gr.Blocks() as demo:
        # The slider maximum is inclusive and valid indices run 0 .. len - 1,
        # so the upper bound is len - 1 (len itself would raise IndexError).
        slider = gr.Slider(
            minimum=0, maximum=len(norm_target) - 1, step=1, label="Dataset sample"
        )
        btn = gr.Button("Analyse")
        audio_out = gr.Audio(label="Audio input")
        with gr.Row():
            wer = gr.Number(label="WER")
            insertions = gr.Number(label="Insertions")
            # The value shown is prediction word count / reference word count.
            relative_length = gr.Number(label="Relative length of prediction / reference")
        text_out = gr.Markdown(label="Text difference")

        # Outputs map one-to-one onto get_visualisation's 5-tuple return value.
        btn.click(
            fn=get_visualisation,
            inputs=slider,
            outputs=[audio_out, wer, insertions, relative_length, text_out],
        )

    demo.launch()