# Whisper Arabic ASR comparison demo (Hugging Face Space).
import random

import gradio as gr
from datasets import Audio
from datasets import load_dataset
from jiwer import wer, cer
from transformers import pipeline

from arabic_normalizer import ArabicTextNormalizer
# Load the Arabic split of Common Voice 11 and keep only the columns we use.
common_voice = load_dataset(
    "mozilla-foundation/common_voice_11_0",
    trust_remote_code = True,
    name = "ar",
    split = "train",
)
common_voice = common_voice.select_columns(["audio", "sentence"])

# Decoding options shared by every Whisper pipeline below.
generate_kwargs = {
    "language": "arabic",
    "task": "transcribe"
}


def _build_asr(model_id):
    """Create an automatic-speech-recognition pipeline for *model_id*."""
    return pipeline(
        "automatic-speech-recognition",
        model = model_id,
        generate_kwargs = generate_kwargs,
    )


# The three models being compared.
asr_whisper_large = _build_asr("openai/whisper-large-v3")
asr_whisper_large_turbo = _build_asr("openai/whisper-large-v3-turbo")
asr_whisper_large_turbo_mboushaba = _build_asr("mboushaba/whisper-large-v3-turbo-arabic")

normalizer = ArabicTextNormalizer()
def generate_audio(index = None):
    """Pick one Common Voice sample, transcribe it with all three Whisper
    pipelines, and compute WER/CER for each against the reference sentence.

    Args:
        index: Optional row index into the dataset. Previously this parameter
            was accepted but ignored; it is now honoured. When None, a random
            row is chosen.

    Returns:
        dict with the playable audio as a ``(sampling_rate, array)`` tuple,
        a ``sentence_info`` label, and per-model normalized predictions plus
        their WER and CER scores.
    """
    global common_voice
    # Whisper expects 16 kHz audio. `cast_column` is lazy, but re-casting on
    # every call still rebuilds the dataset object — only cast when needed.
    if common_voice.features["audio"].sampling_rate != 16000:
        common_voice = common_voice.cast_column("audio", Audio(sampling_rate = 16000))

    # Pick the requested row, or a random one; avoids `shuffle()[0]`, which
    # built a full index permutation of the dataset just to read one sample.
    if index is None:
        index = random.randrange(len(common_voice))
    example = common_voice[index]
    audio = example["audio"]

    # Ground truth, normalized the same way as the predictions.
    reference_text = normalizer(example["sentence"])

    def _evaluate(asr):
        """Run one ASR pipeline; return (prediction, WER, CER)."""
        # The pipeline pops keys out of the input dict, so each call needs a
        # fresh payload. "raw" is the current key (the old code mixed the
        # deprecated "array" key with "raw").
        output = asr({"raw": audio["array"], "sampling_rate": audio["sampling_rate"]})
        predicted = normalizer(output["text"])
        return predicted, wer(reference_text, predicted), cer(reference_text, predicted)

    predicted_text, wer_score, cer_score = _evaluate(asr_whisper_large)
    predicted_text_turbo, wer_score_turbo, cer_score_turbo = _evaluate(asr_whisper_large_turbo)
    (predicted_text_turbo_mboushaba,
     wer_score_turbo_mboushaba,
     cer_score_turbo_mboushaba) = _evaluate(asr_whisper_large_turbo_mboushaba)

    # Label shown next to the audio player: reference text + sampling rate.
    sentence_info = "-".join([reference_text, str(audio["sampling_rate"])])
    return {
        "audio": (
            audio["sampling_rate"],
            audio["array"]
        ),
        "sentence_info": sentence_info,
        "predicted_text": predicted_text,
        "wer_score": wer_score,
        "cer_score": cer_score,
        "predicted_text_turbo": predicted_text_turbo,
        "wer_score_turbo": wer_score_turbo,
        "cer_score_turbo": cer_score_turbo,
        "predicted_text_turbo_mboushaba": predicted_text_turbo_mboushaba,
        "wer_score_turbo_mboushaba": wer_score_turbo_mboushaba,
        "cer_score_turbo_mboushaba": cer_score_turbo_mboushaba
    }
def update_ui():
    """Return four placeholder textboxes labelled "Label 0" … "Label 3"."""
    return [gr.Textbox(label = f"Label {idx}") for idx in range(4)]
with gr.Blocks() as demo:
    gr.HTML("""
    <h1>Whisper Arabic: ASR Comparison (large and large turbo)</h1>""")
    gr.Markdown("""
    This is a demo to compare the outputs, WER & CER of two ASR models (Whisper large and large turbo) using
    arabic dataset from mozilla-foundation/common_voice_11_0
    """)
    num_samples_input = gr.Slider(minimum = 1, maximum = 10, step = 1, value = 4, label = "Number of audio samples")
    generate_button = gr.Button("Generate Samples")

    # Fix: `render` was defined but never invoked and the button had no
    # handler, so clicking "Generate Samples" did nothing. `@gr.render`
    # rebuilds the components below each time the button fires, reading the
    # current slider value.
    @gr.render(inputs = num_samples_input, triggers = [generate_button.click])
    def render(num_samples):
        with gr.Column():
            # Slider values may arrive as floats; coerce for range().
            for _ in range(int(num_samples)):
                # Transcribe one random sample with all three models.
                data = generate_audio()
                # Audio player labelled with the reference sentence + rate.
                gr.Audio(data["audio"], label = data["sentence_info"])
                with gr.Row():
                    with gr.Column():
                        gr.Textbox(value = data["predicted_text"], label = "Whisper large output")
                        gr.Textbox(value = f"WER: {data['wer_score']:.2f}", label = "Word Error Rate")
                        gr.Textbox(value = f"CER: {data['cer_score']:.2f}", label = "Character Error Rate")
                    with gr.Column():
                        gr.Textbox(value = data["predicted_text_turbo"], label = "Whisper large turbo output")
                        gr.Textbox(value = f"WER: {data['wer_score_turbo']:.2f}", label = "Word Error Rate - "
                                                                                          "TURBO ")
                        gr.Textbox(value = f"CER: {data['cer_score_turbo']:.2f}", label = "Character Error "
                                                                                          "Rate - TURBO")
                    with gr.Column():
                        gr.Textbox(value = data["predicted_text_turbo_mboushaba"], label = "Whisper large turbo "
                                                                                           "mboushaba output")
                        gr.Textbox(value = f"WER: {data['wer_score_turbo_mboushaba']:.2f}", label = "Word Error Rate - "
                                                                                                    " mboushaba TURBO ")
                        gr.Textbox(value = f"CER: {data['cer_score_turbo_mboushaba']:.2f}", label = "Character Error "
                                                                                                    "Rate - mboushaba TURBO")

demo.launch(show_error = True)