Update app.py
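This commit completes a previously broken app.py: the old file ended generate_audio with a bare return, left a gr.Audio( and two gr.Textbox(value = calls unclosed, and never launched the interface. The update drops an unused import os, pins the dataset to the train split, adds a third, fine-tuned model (mboushaba/whisper-large-v3-turbo-arabic) to the comparison, returns a full results dict from generate_audio, and wraps the UI in a gr.Blocks context that is launched at the end.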
app.py CHANGED

@@ -1,5 +1,3 @@
-import os
-
 import gradio as gr
 from datasets import Audio
 from datasets import load_dataset
@@ -9,7 +7,8 @@ from transformers import pipeline
 from arabic_normalizer import ArabicTextNormalizer
 
 # Load dataset
-common_voice = load_dataset("mozilla-foundation/common_voice_11_0",trust_remote_code=True, name = "ar",
+common_voice = load_dataset("mozilla-foundation/common_voice_11_0", trust_remote_code = True, name = "ar",
+                            split = "train")
 # select column that will be used
 common_voice = common_voice.select_columns(["audio", "sentence"])
 
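Whisper checkpoints expect 16 kHz input while Common Voice clips are stored at 48 kHz; the `from datasets import Audio` import suggests the audio column is re-decoded somewhere outside the changed hunks. A minimal sketch of that usual pattern, assuming `cast_column` is how it is done here:

```python
from datasets import Audio, load_dataset

# Load the Arabic train split of Common Voice 11, as in the diff.
common_voice = load_dataset("mozilla-foundation/common_voice_11_0", name = "ar",
                            split = "train", trust_remote_code = True)

# Assumed resampling step (not visible in the diff): re-decode the audio
# column at 16 kHz so `audio["array"]` arrives at the rate Whisper expects.
common_voice = common_voice.cast_column("audio", Audio(sampling_rate = 16000))
```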
@@ -21,7 +20,10 @@ generate_kwargs = {
 asr_whisper_large = pipeline("automatic-speech-recognition", model = "openai/whisper-large-v3",
                              generate_kwargs = generate_kwargs)
 asr_whisper_large_turbo = pipeline("automatic-speech-recognition", model = "openai/whisper-large-v3-turbo",
-
+                                   generate_kwargs = generate_kwargs)
+asr_whisper_large_turbo_mboushaba = pipeline("automatic-speech-recognition", model =
+                                             "mboushaba/whisper-large-v3-turbo-arabic",
+                                             generate_kwargs = generate_kwargs)
 normalizer = ArabicTextNormalizer()
 
 
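The hunk header shows that a `generate_kwargs` dict is defined just above this hunk, but its body is outside the diff. For Arabic transcription with Whisper, a plausible configuration (an assumption, not taken from the diff) looks like this:

```python
# Hypothetical contents of the `generate_kwargs` dict shared by all three
# pipelines; the actual values are not visible in the changed hunks.
generate_kwargs = {
    "language": "arabic",  # pin the decoder to Arabic instead of auto-detecting
    "task": "transcribe",  # transcribe in Arabic rather than translate to English
}
```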
@@ -54,14 +56,20 @@ def generate_audio(index = None):
         "sampling_rate": audio["sampling_rate"]
     }
 
+    audio_data_turbo_mboushaba = {
+        "raw": audio["array"],
+        "sampling_rate": audio["sampling_rate"]
+    }
+
     # Perform automatic speech recognition (ASR) directly on the resampled audio array
     asr_output = asr_whisper_large(audio_data)
-
     asr_output_turbo = asr_whisper_large_turbo(audio_data_turbo)
+    asr_output_turbo_mboushaba = asr_whisper_large_turbo_mboushaba(audio_data_turbo_mboushaba)
 
     # Extract the transcription from the ASR model output
     predicted_text = normalizer(asr_output["text"])
     predicted_text_turbo = normalizer(asr_output_turbo["text"])
+    predicted_text_turbo_mboushaba = normalizer(asr_output_turbo_mboushaba["text"])
 
     # Compute WER, Word Accuracy, and CER
     wer_score = wer(reference_text, predicted_text)
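`wer` and `cer` are imported outside the changed hunks; their call signature matches jiwer's, so the import is presumably `from jiwer import wer, cer`. A quick standalone example of how these metrics behave:

```python
from jiwer import wer, cer  # assumed source of the wer/cer used above

reference = "the quick brown fox"
hypothesis = "the quick brown box"
print(f"WER: {wer(reference, hypothesis):.2f}")  # 0.25: one of four words wrong
print(f"CER: {cer(reference, hypothesis):.2f}")  # ~0.05: one character of nineteen wrong
```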
@@ -70,52 +78,74 @@ def generate_audio(index = None):
     wer_score_turbo = wer(reference_text, predicted_text_turbo)
     cer_score_turbo = cer(reference_text, predicted_text_turbo)
 
+    wer_score_turbo_mboushaba = wer(reference_text, predicted_text_turbo_mboushaba)
+    cer_score_turbo_mboushaba = cer(reference_text, predicted_text_turbo_mboushaba)
+
     # Prepare display data: original sentence, sampling rate, ASR transcription, and metrics
     sentence_info = "-".join([reference_text, str(audio["sampling_rate"])])
 
-    return
-
-
-
-
+    return {
+        "audio": (
+            audio["sampling_rate"],
+            audio["array"]
+        ),
+        "sentence_info": sentence_info,
+        "predicted_text": predicted_text,
+        "wer_score": wer_score,
+        "cer_score": cer_score,
+        "predicted_text_turbo": predicted_text_turbo,
+        "wer_score_turbo": wer_score_turbo,
+        "cer_score_turbo": cer_score_turbo,
+        "predicted_text_turbo_mboushaba": predicted_text_turbo_mboushaba,
+        "wer_score_turbo_mboushaba": wer_score_turbo_mboushaba,
+        "cer_score_turbo_mboushaba": cer_score_turbo_mboushaba
+    }
+
 
 def update_ui():
     res = []
     for i in range(4):
-        res.append(gr.Textbox(label=f"Label {i}"))
+        res.append(gr.Textbox(label = f"Label {i}"))
     return res
 
-
+
+with gr.Blocks() as demo:
     gr.HTML("""
     <h1>Whisper Arabic: ASR Comparison (large and large turbo)</h1>""")
     gr.Markdown("""
    This is a demo to compare the outputs, WER & CER of two ASR models (Whisper large and large turbo) using
    arabic dataset from mozilla-foundation/common_voice_11_0
    """)
-    num_samples_input = gr.Slider(minimum=1, maximum=10, step=1, value=4, label="Number of audio samples")
+    num_samples_input = gr.Slider(minimum = 1, maximum = 10, step = 1, value = 4, label = "Number of audio samples")
     generate_button = gr.Button("Generate Samples")
 
 
-    @gr.render(inputs=num_samples_input, triggers=[generate_button.click])
+    @gr.render(inputs = num_samples_input, triggers = [generate_button.click])
     def render(num_samples):
         with gr.Column():
             for i in range(num_samples):
                 # Generate audio and associated data
-
+                data = generate_audio()
 
                 # Create Gradio components to display the audio, transcription, and metrics
-                gr.Audio(
+                gr.Audio(data["audio"], label = data["sentence_info"])
                 with gr.Row():
                     with gr.Column():
-                        gr.Textbox(value =
-                        gr.Textbox(value = f"WER: {wer_score:.2f}", label = "Word Error Rate"),
-                        gr.Textbox(value = f"CER: {cer_score:.2f}", label = "Character Error Rate"),
+                        gr.Textbox(value = data["predicted_text"], label = "Whisper large output"),
+                        gr.Textbox(value = f"WER: {data['wer_score']:.2f}", label = "Word Error Rate"),
+                        gr.Textbox(value = f"CER: {data['cer_score']:.2f}", label = "Character Error Rate"),
+                    with gr.Column():
+                        gr.Textbox(value = data["predicted_text_turbo"], label = "Whisper large turbo output"),
+                        gr.Textbox(value = f"WER: {data['wer_score_turbo']:.2f}", label = "Word Error Rate - "
+                                                                                          "TURBO "),
+                        gr.Textbox(value = f"CER: {data['cer_score_turbo']:.2f}", label = "Character Error "
+                                                                                          "Rate - TURBO")
                     with gr.Column():
-                        gr.Textbox(value =
-
-
-
-
-
-
-
+                        gr.Textbox(value = data["predicted_text_turbo_mboushaba"], label = "Whisper large turbo "
+                                                                                           "mboushaba output"),
+                        gr.Textbox(value = f"WER: {data['wer_score_turbo_mboushaba']:.2f}", label = "Word Error Rate - "
+                                                                                                    " mboushaba TURBO "),
+                        gr.Textbox(value = f"CER: {data['cer_score_turbo_mboushaba']:.2f}", label = "Character Error "
+                                                                                                    "Rate - mboushaba TURBO")
+
+demo.launch(show_error = True)
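`render` calls `generate_audio()` with no argument, so the `index = None` default presumably makes the function pick a random dataset row; that selection logic sits outside the changed hunks. A sketch of the assumed behaviour, using a hypothetical `pick_sample` helper:

```python
import random

# Hypothetical illustration of what `index = None` likely means inside
# `generate_audio`; the real selection code is not part of this diff.
def pick_sample(dataset, index = None):
    if index is None:
        index = random.randint(0, len(dataset) - 1)
    return dataset[index]  # a dict with "audio" and "sentence" keys
```

One design note: the trailing commas after the `gr.Textbox(...)` calls build throwaway tuples. They are harmless here, because components inside a Blocks context render as a side effect of instantiation, but they could be dropped.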