sanchit-gandhi (HF staff) committed
Commit e676bd8
1 Parent(s): 9d85ee2

three tabs

Files changed (1)
  1. app.py +94 -32
app.py CHANGED
@@ -1,4 +1,5 @@
 import os
+from functools import partial

 import numpy as np
 import unicodedata
@@ -46,63 +47,124 @@ dataset = load_dataset(
     "distil-whisper/tedlium-long-form", split="validation", num_proc=os.cpu_count()
 )

-csv = pd.read_csv("assets/large-v2.csv")
+csv_v2 = pd.read_csv("assets/large-v2.csv")

-norm_target = csv["Norm Target"]
-norm_pred = csv["Norm Pred"]
+norm_target = csv_v2["Norm Target"]
+norm_pred_v2 = csv_v2["Norm Pred"]

 norm_target = [norm_target[i] for i in range(len(norm_target))]
-norm_pred = [norm_pred[i] for i in range(len(norm_pred))]
+norm_pred_v2 = [norm_pred_v2[i] for i in range(len(norm_pred_v2))]
+
+csv_32_2 = pd.read_csv("assets/large-32-2.csv")
+
+norm_pred_32_2 = csv_32_2["Norm Pred"]
+norm_pred_32_2 = [norm_pred_32_2[i] for i in range(len(norm_pred_32_2))]

 target_dtype = np.int16
 max_range = np.iinfo(target_dtype).max


-def get_visualisation(idx):
+def get_visualisation(idx, model="v2"):
     idx -= 1
     audio = dataset[idx]["audio"]
     array = (audio["array"] * max_range).astype(np.int16)
     sampling_rate = audio["sampling_rate"]

     text1 = norm_target[idx]
-    text2 = norm_pred[idx]
+    text2 = norm_pred_v2[idx] if model == "v2" else norm_pred_32_2[idx]

     wer_output = process_words(text1, text2, wer_default, wer_default)
-    wer_percentage = 100 * wer_output.wer
-    rel_insertions = wer_output.insertions / len(text1.split())
+    wer_percentage = round(100 * wer_output.wer, 2)
+    ier_percentage = round(100 * wer_output.insertions / len(wer_output.references[0]), 2)

-    rel_length = len(text2.split()) / len(text1.split())
+    rel_length = round(len(text2.split()) / len(text1.split()), 2)

     diff = compare_string(text1, text2)
     full_text = style_text(diff)

-    return (sampling_rate, array), wer_percentage, rel_insertions, rel_length, full_text
+    return (sampling_rate, array), wer_percentage, ier_percentage, rel_length, full_text
+
+def get_side_by_side_visualisation(idx):
+    large_v2 = get_visualisation(idx, model="v2")
+    large_32_2 = get_visualisation(idx, model="32-2")
+    table = [large_v2[1:-1], large_32_2[1:-1]]
+    table[0] = ["large-v2", *table[0]]
+    table[1] = ["large-32-2", *table[1]]
+    return large_v2[0], table, large_v2[-1], large_32_2[-1]


 if __name__ == "__main__":
-    gr.Markdown(
-        "Analyse the transcriptions generated by the Whisper large-v2 model on the TEDLIUM dev set."
-    )
     with gr.Blocks() as demo:
-        slider = gr.Slider(
-            minimum=1, maximum=len(norm_target), step=1, label="Dataset sample"
-        )
-        btn = gr.Button("Analyse")
-        audio_out = gr.Audio(label="Audio input")
-        with gr.Row():
-            wer = gr.Number(label="WER")
-            relative_insertions = gr.Number(
-                label="Relative insertions (# insertions / target length)"
-            )
-            relative_length = gr.Number(
-                label="Relative length (reference length / target length)"
-            )
-        text_out = gr.Markdown(label="Text difference")
-
-        btn.click(
-            fn=get_visualisation,
-            inputs=slider,
-            outputs=[audio_out, wer, relative_insertions, relative_length, text_out],
-        )
-
+        with gr.Tab("large-v2"):
+            gr.Markdown(
+                "Analyse the transcriptions generated by the Whisper large-v2 model on the TEDLIUM dev set."
+            )
+
+            slider = gr.Slider(
+                minimum=1, maximum=len(norm_target), step=1, label="Dataset sample"
+            )
+            btn = gr.Button("Analyse")
+            audio_out = gr.Audio(label="Audio input")
+            with gr.Row():
+                wer = gr.Number(label="Word Error Rate (WER)")
+                ier = gr.Number(
+                    label="Insertion Error Rate (IER)"
+                )
+                relative_length = gr.Number(
+                    label="Relative length (reference length / target length)"
+                )
+            text_out = gr.Markdown(label="Text difference")
+
+            btn.click(
+                fn=partial(get_visualisation, model="v2"),
+                inputs=slider,
+                outputs=[audio_out, wer, ier, relative_length, text_out],
+            )
+        with gr.Tab("large-32-2"):
+            gr.Markdown(
+                "Analyse the transcriptions generated by the Whisper large-32-2 model on the TEDLIUM dev set."
+            )
+            slider = gr.Slider(
+                minimum=1, maximum=len(norm_target), step=1, label="Dataset sample"
+            )
+            btn = gr.Button("Analyse")
+            audio_out = gr.Audio(label="Audio input")
+            with gr.Row():
+                wer = gr.Number(label="Word Error Rate (WER)")
+                ier = gr.Number(
+                    label="Insertion Error Rate (IER)"
+                )
+                relative_length = gr.Number(
+                    label="Relative length (reference length / target length)"
+                )
+            text_out = gr.Markdown(label="Text difference")
+
+            btn.click(
+                fn=partial(get_visualisation, model="32-2"),
+                inputs=slider,
+                outputs=[audio_out, wer, ier, relative_length, text_out],
+            )
+        with gr.Tab("side-by-side"):
+            gr.Markdown(
+                "Compare the transcriptions generated by the Whisper large-v2 and large-32-2 models on the TEDLIUM dev set."
+            )
+            slider = gr.Slider(
+                minimum=1, maximum=len(norm_target), step=1, label="Dataset sample"
+            )
+            btn = gr.Button("Analyse")
+            audio_out = gr.Audio(label="Audio input")
+            with gr.Column():
+                table = gr.Dataframe(headers=["Model", "Word Error Rate (WER)", "Insertion Error Rate (IER)", "Rel length (ref length / tgt length)"], height=1000)
+                with gr.Row():
+                    gr.Markdown("large-v2 text diff")
+                    gr.Markdown("large-32-2 text diff")
+                with gr.Row():
+                    text_out_v2 = gr.Markdown(label="Text difference")
+                    text_out_32_2 = gr.Markdown(label="Text difference")
+
+            btn.click(
+                fn=get_side_by_side_visualisation,
+                inputs=slider,
+                outputs=[audio_out, table, text_out_v2, text_out_32_2],
+            )
     demo.launch()
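A note on the audio handling in get_visualisation: the datasets Audio feature returns a float waveform in [-1, 1], while gr.Audio plays back a (sampling_rate, int16_array) tuple. A minimal sketch of the scaling step, with a made-up waveform:

import numpy as np

target_dtype = np.int16
max_range = np.iinfo(target_dtype).max  # 32767

# made-up float waveform in [-1, 1], as returned by the datasets Audio feature
float_audio = np.array([0.0, 0.5, -1.0])
int_audio = (float_audio * max_range).astype(target_dtype)  # [0, 16383, -32767]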
 
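The WER and IER figures come from jiwer's word-level alignment. A minimal sketch of the same computation, assuming jiwer >= 3.0 (the example strings are invented; calling process_words without explicit transforms applies wer_default to both sides, matching the process_words(text1, text2, wer_default, wer_default) call in the diff):

import jiwer

reference = "the cat sat on the mat"
hypothesis = "the cat sat on on the mat"  # one inserted word

out = jiwer.process_words(reference, hypothesis)

wer_percentage = round(100 * out.wer, 2)
# IER: insertions relative to the number of normalised reference words
ier_percentage = round(100 * out.insertions / len(out.references[0]), 2)

print(wer_percentage, ier_percentage)  # 16.67 16.67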
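Both model tabs bind the same callback with functools.partial, which freezes the model keyword so Gradio only has to supply the slider value. A standalone sketch of the pattern, with the function body stubbed out:

from functools import partial

def get_visualisation(idx, model="v2"):
    return f"sample {idx}, model {model}"  # stub for illustration

analyse_v2 = partial(get_visualisation, model="v2")
analyse_32_2 = partial(get_visualisation, model="32-2")

print(analyse_v2(3))    # sample 3, model v2
print(analyse_32_2(3))  # sample 3, model 32-2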
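In get_side_by_side_visualisation, each per-model result has the shape (audio, wer, ier, rel_length, text_diff); slicing with [1:-1] keeps only the three metric columns, which become one labelled row per model in the gr.Dataframe. A sketch with invented metric values:

# invented result tuples in the shape returned by get_visualisation
large_v2 = ("<audio>", 10.5, 2.1, 1.02, "<v2 diff>")
large_32_2 = ("<audio>", 9.8, 1.7, 1.01, "<32-2 diff>")

table = [large_v2[1:-1], large_32_2[1:-1]]  # drop audio (first) and text diff (last)
table[0] = ["large-v2", *table[0]]
table[1] = ["large-32-2", *table[1]]
# table == [["large-v2", 10.5, 2.1, 1.02], ["large-32-2", 9.8, 1.7, 1.01]]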