|
import html
import unicodedata
from enum import Enum

import diff_match_patch as dmp_module
import gradio as gr
import numpy as np
import pandas as pd
from datasets import load_dataset
from jiwer import process_words, wer_default
|
|
|
|
|
class Action(Enum):
    """Integer op codes that tag each segment of a diff.

    The values are compared directly against the first element of each
    ``(op, text)`` tuple returned by ``diff_match_patch.diff_main``.
    """

    INSERTION = 1  # segment present only in the second (predicted) text
    DELETION = -1  # segment present only in the first (target) text
    EQUAL = 0  # segment shared by both texts
|
|
|
|
|
def compare_string(text1: str, text2: str) -> list:
    """Return a human-readable character diff turning ``text1`` into ``text2``.

    Both inputs are NFKC-normalised first, so compatibility-equivalent
    characters (e.g. full-width vs. ASCII forms) do not show up as changes.

    Returns:
        The list of ``(op, text)`` tuples produced by diff_match_patch, after
        ``diff_cleanupSemantic`` has merged trivial edits into larger,
        more readable chunks.
    """
    left, right = (unicodedata.normalize("NFKC", s) for s in (text1, text2))

    differ = dmp_module.diff_match_patch()
    segments = differ.diff_main(left, right)
    # The cleanup mutates ``segments`` in place; its return value is not used.
    differ.diff_cleanupSemantic(segments)
    return segments
|
|
|
|
|
def style_text(diff) -> str:
    """Render a diff as HTML-in-Markdown with insertions and deletions highlighted.

    Args:
        diff: Iterable of ``(action, text)`` pairs whose ``action`` is one of
            the ``Action`` values (e.g. the output of ``compare_string``).

    Returns:
        A single string in which inserted text is wrapped in a light-green
        ``<span>``, deleted text in a light-red struck-through ``<span>``, and
        unchanged text is emitted as-is. Segment text is HTML-escaped, and the
        Markdown sequences ``](`` and ``~`` are backslash-escaped so the
        renderer does not turn them into links or strikethrough.

    Raises:
        ValueError: if an action is not one of the ``Action`` values.
    """
    parts = []
    for action, text in diff:
        # Escape &, < and > so transcript text can never be parsed as HTML markup.
        safe_text = html.escape(text, quote=False)
        if action == Action.INSERTION.value:
            parts.append(f"<span style='background-color:Lightgreen'>{safe_text}</span>")
        elif action == Action.DELETION.value:
            parts.append(f"<span style='background-color:#FFCCCB'><s>{safe_text}</s></span>")
        elif action == Action.EQUAL.value:
            parts.append(safe_text)
        else:
            raise ValueError(f"Not Implemented: unknown diff action {action!r}")

    full_text = "".join(parts)
    # Raw strings: "\(" and "\~" are invalid escape sequences in normal literals
    # (SyntaxWarning on Python 3.12+); the resulting values are unchanged.
    return full_text.replace("](", r"]\(").replace("~", r"\~")
|
|
|
|
|
# Long-form TED-LIUM validation split; row i is assumed to correspond to row i
# of the CSV below (NOTE(review): alignment is not checked here — confirm).
dataset = load_dataset("distil-whisper/tedlium-long-form", split="validation")

# Pre-computed normalised transcripts (presumably Whisper large-v2 output,
# per the file name).
csv = pd.read_csv("assets/large-v2.csv")

# Empty CSV cells are read as float NaN, which would later crash str.split()
# and jiwer; treat them as empty transcripts instead.
norm_target = csv["Norm Target"].fillna("").tolist()
norm_pred = csv["Norm Pred"].fillna("").tolist()

# Float audio samples are scaled by this to produce 16-bit PCM.
target_dtype = np.int16
max_range = np.iinfo(target_dtype).max
|
|
|
|
|
def get_visualisation(idx):
    """Compute every output shown in the demo for one dataset sample.

    Args:
        idx: Row index into ``dataset``, ``norm_target`` and ``norm_pred``.
            Coerced to ``int`` because UI slider values may arrive as floats.

    Returns:
        A 5-tuple ``((sampling_rate, pcm), wer_percentage, num_insertions,
        rel_length, full_text)``:
          * ``pcm`` - the audio as ``target_dtype`` (int16) samples,
          * ``wer_percentage`` - word error rate of prediction vs. target, in %,
          * ``num_insertions`` - number of inserted words reported by jiwer,
          * ``rel_length`` - prediction word count / target word count,
          * ``full_text`` - highlighted character diff (see ``style_text``).
    """
    idx = int(idx)

    audio = dataset[idx]["audio"]
    # Clip before scaling: float samples outside [-1, 1] would otherwise
    # overflow and wrap around when cast to int16, producing loud clicks.
    samples = np.clip(audio["array"], -1.0, 1.0)
    array = (samples * max_range).astype(target_dtype)
    sampling_rate = audio["sampling_rate"]

    text1 = norm_target[idx]
    text2 = norm_pred[idx]

    # Reference first, hypothesis second; same default WER transform for both.
    wer_output = process_words(text1, text2, wer_default, wer_default)
    wer_percentage = 100 * wer_output.wer
    num_insertions = wer_output.insertions

    rel_length = len(text2.split()) / len(text1.split())

    diff = compare_string(text1, text2)
    full_text = style_text(diff)

    return (sampling_rate, array), wer_percentage, num_insertions, rel_length, full_text
|
|
|
|
|
if __name__ == "__main__":
    with gr.Blocks() as demo:
        # The slider's maximum is inclusive, so the last valid index is
        # len(norm_target) - 1; len(norm_target) itself would raise IndexError.
        slider = gr.Slider(
            minimum=0, maximum=len(norm_target) - 1, step=1, label="Dataset sample"
        )
        btn = gr.Button("Analyse")
        audio_out = gr.Audio(label="Audio input")
        with gr.Row():
            wer = gr.Number(label="WER")
            insertions = gr.Number(label="Insertions")
            # get_visualisation reports (prediction word count) / (target word count).
            relative_length = gr.Number(label="Relative length of prediction / target")
        text_out = gr.Markdown(label="Text difference")

        btn.click(
            fn=get_visualisation,
            inputs=slider,
            outputs=[audio_out, wer, insertions, relative_length, text_out],
        )

    demo.launch()
|
|