mboushaba commited on
Commit
e779c90
1 Parent(s): 460f7e6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -27
app.py CHANGED
@@ -1,5 +1,3 @@
1
- import os
2
-
3
  import gradio as gr
4
  from datasets import Audio
5
  from datasets import load_dataset
@@ -9,7 +7,8 @@ from transformers import pipeline
9
  from arabic_normalizer import ArabicTextNormalizer
10
 
11
  # Load dataset
12
- common_voice = load_dataset("mozilla-foundation/common_voice_11_0",trust_remote_code=True, name = "ar", split = "train")
 
13
  # select column that will be used
14
  common_voice = common_voice.select_columns(["audio", "sentence"])
15
 
@@ -21,7 +20,10 @@ generate_kwargs = {
21
  asr_whisper_large = pipeline("automatic-speech-recognition", model = "openai/whisper-large-v3",
22
  generate_kwargs = generate_kwargs)
23
  asr_whisper_large_turbo = pipeline("automatic-speech-recognition", model = "openai/whisper-large-v3-turbo",
24
- generate_kwargs = generate_kwargs)
 
 
 
25
  normalizer = ArabicTextNormalizer()
26
 
27
 
@@ -54,14 +56,20 @@ def generate_audio(index = None):
54
  "sampling_rate": audio["sampling_rate"]
55
  }
56
 
 
 
 
 
 
57
  # Perform automatic speech recognition (ASR) directly on the resampled audio array
58
  asr_output = asr_whisper_large(audio_data)
59
-
60
  asr_output_turbo = asr_whisper_large_turbo(audio_data_turbo)
 
61
 
62
  # Extract the transcription from the ASR model output
63
  predicted_text = normalizer(asr_output["text"])
64
  predicted_text_turbo = normalizer(asr_output_turbo["text"])
 
65
 
66
  # Compute WER, Word Accuracy, and CER
67
  wer_score = wer(reference_text, predicted_text)
@@ -70,52 +78,74 @@ def generate_audio(index = None):
70
  wer_score_turbo = wer(reference_text, predicted_text_turbo)
71
  cer_score_turbo = cer(reference_text, predicted_text_turbo)
72
 
 
 
 
73
  # Prepare display data: original sentence, sampling rate, ASR transcription, and metrics
74
  sentence_info = "-".join([reference_text, str(audio["sampling_rate"])])
75
 
76
- return ((
77
- audio["sampling_rate"],
78
- audio["array"]
79
- ), sentence_info, predicted_text, wer_score, cer_score, predicted_text_turbo,
80
- wer_score_turbo, cer_score_turbo)
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
  def update_ui():
83
  res = []
84
  for i in range(4):
85
- res.append(gr.Textbox(label=f"Label {i}"))
86
  return res
87
 
88
- with (gr.Blocks() as demo):
 
89
  gr.HTML("""
90
  <h1>Whisper Arabic: ASR Comparison (large and large turbo)</h1>""")
91
  gr.Markdown("""
92
  This is a demo to compare the outputs, WER & CER of two ASR models (Whisper large and large turbo) using
93
  arabic dataset from mozilla-foundation/common_voice_11_0
94
  """)
95
- num_samples_input = gr.Slider(minimum=1, maximum=10, step=1, value=4, label="Number of audio samples")
96
  generate_button = gr.Button("Generate Samples")
97
 
98
 
99
- @gr.render(inputs=num_samples_input, triggers=[generate_button.click])
100
  def render(num_samples):
101
  with gr.Column():
102
  for i in range(num_samples):
103
  # Generate audio and associated data
104
- _audio, label, asr_text, wer_score, cer_score, asr_text_turbo, wer_score_turbo, cer_score_turbo =generate_audio()
105
 
106
  # Create Gradio components to display the audio, transcription, and metrics
107
- gr.Audio(_audio, label = label)
108
  with gr.Row():
109
  with gr.Column():
110
- gr.Textbox(value = asr_text, label = "Whisper large output"),
111
- gr.Textbox(value = f"WER: {wer_score:.2f}", label = "Word Error Rate"),
112
- gr.Textbox(value = f"CER: {cer_score:.2f}", label = "Character Error Rate"),
 
 
 
 
 
 
113
  with gr.Column():
114
- gr.Textbox(value = asr_text_turbo, label = "Whisper large turbo output"),
115
- gr.Textbox(value = f"WER: {wer_score_turbo:.2f}", label = "Word Error Rate - "
116
- "TURBO "),
117
- gr.Textbox(value = f"CER: {cer_score_turbo:.2f}", label = "Character Error "
118
- "Rate - TURBO")
119
-
120
- if __name__ == '__main__':
121
- demo.launch(show_error = True)
 
 
 
1
  import gradio as gr
2
  from datasets import Audio
3
  from datasets import load_dataset
 
7
  from arabic_normalizer import ArabicTextNormalizer
8
 
9
  # Load dataset
10
# Load the Arabic split of Common Voice 11 and keep only the fields the demo
# uses (the raw audio and its reference sentence).
common_voice = load_dataset(
    "mozilla-foundation/common_voice_11_0",
    name = "ar",
    split = "train",
    trust_remote_code = True,
)
common_voice = common_voice.select_columns(["audio", "sentence"])
14
 
 
20
def _make_asr_pipeline(model_id):
    # One-line helper: build an ASR pipeline for *model_id* with the shared
    # generation settings (generate_kwargs defined above).
    return pipeline("automatic-speech-recognition", model = model_id,
                    generate_kwargs = generate_kwargs)


# The three models under comparison: stock large-v3, the turbo variant, and a
# fine-tuned Arabic turbo checkpoint.
asr_whisper_large = _make_asr_pipeline("openai/whisper-large-v3")
asr_whisper_large_turbo = _make_asr_pipeline("openai/whisper-large-v3-turbo")
asr_whisper_large_turbo_mboushaba = _make_asr_pipeline("mboushaba/whisper-large-v3-turbo-arabic")

# Shared text normalizer applied to every transcription before scoring.
normalizer = ArabicTextNormalizer()
28
 
29
 
 
56
  "sampling_rate": audio["sampling_rate"]
57
  }
58
 
59
+ audio_data_turbo_mboushaba = {
60
+ "raw": audio["array"],
61
+ "sampling_rate": audio["sampling_rate"]
62
+ }
63
+
64
  # Perform automatic speech recognition (ASR) directly on the resampled audio array
65
  asr_output = asr_whisper_large(audio_data)
 
66
  asr_output_turbo = asr_whisper_large_turbo(audio_data_turbo)
67
+ asr_output_turbo_mboushaba = asr_whisper_large_turbo_mboushaba(audio_data_turbo_mboushaba)
68
 
69
  # Extract the transcription from the ASR model output
70
  predicted_text = normalizer(asr_output["text"])
71
  predicted_text_turbo = normalizer(asr_output_turbo["text"])
72
+ predicted_text_turbo_mboushaba = normalizer(asr_output_turbo_mboushaba["text"])
73
 
74
  # Compute WER, Word Accuracy, and CER
75
  wer_score = wer(reference_text, predicted_text)
 
78
  wer_score_turbo = wer(reference_text, predicted_text_turbo)
79
  cer_score_turbo = cer(reference_text, predicted_text_turbo)
80
 
81
+ wer_score_turbo_mboushaba = wer(reference_text, predicted_text_turbo_mboushaba)
82
+ cer_score_turbo_mboushaba = cer(reference_text, predicted_text_turbo_mboushaba)
83
+
84
  # Prepare display data: original sentence, sampling rate, ASR transcription, and metrics
85
  sentence_info = "-".join([reference_text, str(audio["sampling_rate"])])
86
 
87
+ return {
88
+ "audio": (
89
+ audio["sampling_rate"],
90
+ audio["array"]
91
+ ),
92
+ "sentence_info": sentence_info,
93
+ "predicted_text": predicted_text,
94
+ "wer_score": wer_score,
95
+ "cer_score": cer_score,
96
+ "predicted_text_turbo": predicted_text_turbo,
97
+ "wer_score_turbo": wer_score_turbo,
98
+ "cer_score_turbo": cer_score_turbo,
99
+ "predicted_text_turbo_mboushaba": predicted_text_turbo_mboushaba,
100
+ "wer_score_turbo_mboushaba": wer_score_turbo_mboushaba,
101
+ "cer_score_turbo_mboushaba": cer_score_turbo_mboushaba
102
+ }
103
+
104
 
105
def update_ui():
    """Create and return four placeholder Textbox components ("Label 0".."Label 3")."""
    return [gr.Textbox(label = f"Label {i}") for i in range(4)]
110
 
111
+
112
with gr.Blocks() as demo:
    # Header and description. NOTE: updated to mention all THREE compared
    # models — the previous text still said "two" after the fine-tuned
    # mboushaba model was added.
    gr.HTML("""
    <h1>Whisper Arabic: ASR Comparison (large, large turbo and fine-tuned turbo)</h1>""")
    gr.Markdown("""
    This is a demo to compare the outputs, WER & CER of three ASR models (Whisper large, large turbo and the
    fine-tuned mboushaba/whisper-large-v3-turbo-arabic) using the arabic dataset from
    mozilla-foundation/common_voice_11_0
    """)
    num_samples_input = gr.Slider(minimum = 1, maximum = 10, step = 1, value = 4, label = "Number of audio samples")
    generate_button = gr.Button("Generate Samples")


    @gr.render(inputs = num_samples_input, triggers = [generate_button.click])
    def render(num_samples):
        # Re-rendered on each button click: one audio player plus three metric
        # columns (one per model) for every requested sample.
        with gr.Column():
            for i in range(num_samples):
                # Run all three models on one dataset sample and collect
                # transcriptions + WER/CER scores (see generate_audio above).
                data = generate_audio()

                gr.Audio(data["audio"], label = data["sentence_info"])
                with gr.Row():
                    # Trailing commas after these calls previously turned each
                    # statement into a throwaway 1-tuple; removed.
                    with gr.Column():
                        gr.Textbox(value = data["predicted_text"], label = "Whisper large output")
                        gr.Textbox(value = f"WER: {data['wer_score']:.2f}", label = "Word Error Rate")
                        gr.Textbox(value = f"CER: {data['cer_score']:.2f}", label = "Character Error Rate")
                    with gr.Column():
                        gr.Textbox(value = data["predicted_text_turbo"], label = "Whisper large turbo output")
                        gr.Textbox(value = f"WER: {data['wer_score_turbo']:.2f}", label = "Word Error Rate - "
                                                                                         "TURBO ")
                        gr.Textbox(value = f"CER: {data['cer_score_turbo']:.2f}", label = "Character Error "
                                                                                         "Rate - TURBO")
                    with gr.Column():
                        gr.Textbox(value = data["predicted_text_turbo_mboushaba"], label = "Whisper large turbo "
                                                                                          "mboushaba output")
                        gr.Textbox(value = f"WER: {data['wer_score_turbo_mboushaba']:.2f}",
                                   label = "Word Error Rate - "
                                           " mboushaba TURBO ")
                        gr.Textbox(value = f"CER: {data['cer_score_turbo_mboushaba']:.2f}",
                                   label = "Character Error "
                                           "Rate - mboushaba TURBO")

# Restore the entry-point guard that the previous revision had and this diff
# dropped: launch only when run as a script, not on import.
if __name__ == '__main__':
    demo.launch(show_error = True)