# Import Libraries to load Models
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from transformers import Speech2TextForConditionalGeneration, Speech2TextProcessor
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

# Import Libraries to access Datasets
from datasets import load_dataset
from datasets import Audio

# Helper Libraries
import plotly.graph_objs as go
import evaluate
import librosa
import torch
import numpy as np
import pandas as pd
import time
# This constant determines how many samples the models are evaluated on
N_SAMPLES = 50

# Load the WER Metric
wer_metric = evaluate.load("wer")
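# Quick reference for how the metric behaves: WER = (substitutions + deletions + insertions) / number of reference words.
# For example, wer_metric.compute(references=["the cat sat"], predictions=["the cat sit"]) returns 1/3,
# since one of the three reference words has to be substituted.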
def run(data_subset:str, model_1:str, model_2:str, own_audio, own_transcription:str):
    """
    Main function running an entire evaluation cycle.
    Params:
        - data_subset (str): The name of a valid dataset to choose from ["Common Voice", "Librispeech ASR clean", "Librispeech ASR other", "OWN Recording/Sample"]
        - model_1 (str): The name of a valid model to choose from ["openai/whisper-tiny.en", "facebook/s2t-medium-librispeech-asr", "facebook/wav2vec2-base-960h", "openai/whisper-large-v2"]
        - model_2 (str): The name of a valid model to choose from ["openai/whisper-tiny.en", "facebook/s2t-medium-librispeech-asr", "facebook/wav2vec2-base-960h", "openai/whisper-large-v2"]
        - own_audio (gr.Audio): The return value of a gr.Audio component, a tuple (sr, audio as numpy array)
        - own_transcription (str): The paired transcription to own_audio
    """
    # A little bit of Error Handling
    if data_subset is None and own_audio is None and own_transcription is None:
        raise ValueError("No Dataset selected")
    if model_1 is None:
        raise ValueError("No Model 1 selected")
    if model_2 is None:
        raise ValueError("No Model 2 selected")
    # Load the selected Dataset, but only N_SAMPLES of it
    if data_subset == "Common Voice":
        dataset, text_column = load_Common_Voice()
    elif data_subset == "Librispeech ASR clean":
        dataset, text_column = load_Librispeech_ASR_clean()
    elif data_subset == "Librispeech ASR other":
        dataset, text_column = load_Librispeech_ASR_other()
elif data_subset == "OWN Recording/Sample": | |
sr, audio = own_audio | |
audio = audio.astype(np.float32) | |
print("AUDIO: ", type(audio), audio) | |
audio = librosa.resample(audio, orig_sr=sr, target_sr=16000) | |
    else:
        # If data_subset is None, fall back to Common Voice
        dataset, text_column = load_Common_Voice()
    # I have left the print statements in because users have access to the logs in Spaces, and they might help with understanding what's going on
print("Dataset Loaded") | |
# Load the selected Models | |
model1, processor1 = load_model(model_1) | |
model2, processor2 = load_model(model_2) | |
print("Models Loaded") | |
    # In case an own recording is selected, only a single sample has to be evaluated
    if data_subset == "OWN Recording/Sample":
        sample = {"audio": {"array": audio, "sampling_rate": 16000}}

        inference_times1 = []
        inference_times2 = []

        time_start = time.time()
        transcription1 = model_compute(model1, processor1, sample, model_1)
        time_stop = time.time()
        duration = time_stop - time_start
        inference_times1.append(duration)

        time_start = time.time()
        transcription2 = model_compute(model2, processor2, sample, model_2)
        time_stop = time.time()
        duration = time_stop - time_start
        inference_times2.append(duration)

        transcriptions1 = [transcription1]
        transcriptions2 = [transcription2]
        references = [own_transcription.lower()]
        # WER for the single own recording
        wer1 = round(compute_wer(references, transcriptions1), 4)
        wer2 = round(compute_wer(references, transcriptions2), 4)
results_md = f""" | |
#### {model_1} | |
- WER Score: {wer1} | |
- Avg. Inference Duration: {round(sum(inference_times1)/len(inference_times1), 4)}s | |
#### {model_2} | |
- WER Score: {wer2} | |
- Avg. Inference Duration: {round(sum(inference_times2)/len(inference_times2), 4)}s""" | |
        # Create the bar plot
        fig = go.Figure(
            data=[
                go.Bar(x=[f"{model_1}"], y=[wer1], showlegend=False),
                go.Bar(x=[f"{model_2}"], y=[wer2], showlegend=False),
            ]
        )
        # Update the layout for better visualization
        fig.update_layout(
            title="Comparison of Two Models",
            xaxis_title="Models",
            yaxis_title="WER",
            barmode="group",
        )
        df = pd.DataFrame({"references": references, "transcriptions 1": transcriptions1, "WER 1": [wer1], "transcriptions 2": transcriptions2, "WER 2": [wer2]})
        yield results_md, fig, df
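        # run() is a generator so that the caller (presumably a Gradio interface, not shown in this file) can
        # render results as soon as they are ready; for an own recording there is only this single yield.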
    # In case a Dataset has been selected
    else:
        references = []
        transcriptions1 = []
        transcriptions2 = []
        WER1s = []
        WER2s = []
        inference_times1 = []
        inference_times2 = []
        counter = 0

        for i, sample in enumerate(dataset, start=1):
            print(counter)
            counter += 1
            references.append(sample[text_column])
            if model_1 == model_2:
                # Identical models only need to be run once per sample
                time_start = time.time()
                transcription = model_compute(model1, processor1, sample, model_1)
                time_stop = time.time()
                duration = time_stop - time_start
                inference_times1.append(duration)
                inference_times2.append(duration)
                transcriptions1.append(transcription)
                transcriptions2.append(transcription)
            else:
                time_start = time.time()
                transcription1 = model_compute(model1, processor1, sample, model_1)
                time_stop = time.time()
                duration = time_stop - time_start
                inference_times1.append(duration)
                transcriptions1.append(transcription1)

                time_start = time.time()
                transcription2 = model_compute(model2, processor2, sample, model_2)
                time_stop = time.time()
                duration = time_stop - time_start
                inference_times2.append(duration)
                transcriptions2.append(transcription2)

            # Per-sample WER, computed from the latest transcriptions so both branches above are covered
            WER1s.append(round(compute_wer([sample[text_column]], [transcriptions1[-1]]), 4))
            WER2s.append(round(compute_wer([sample[text_column]], [transcriptions2[-1]]), 4))
            wer1 = round(sum(WER1s)/len(WER1s), 4)
            wer2 = round(sum(WER2s)/len(WER2s), 4)

            results_md = f"""
{i}/{len(dataset)}-{'#'*i}{'_'*(N_SAMPLES-i)}
#### {model_1}
- WER Score: {wer1}
- Avg. Inference Duration: {round(sum(inference_times1)/len(inference_times1), 4)}s

#### {model_2}
- WER Score: {wer2}
- Avg. Inference Duration: {round(sum(inference_times2)/len(inference_times2), 4)}s"""
            # Create the bar plot
            fig = go.Figure(
                data=[
                    go.Bar(x=[f"{model_1}"], y=[wer1], showlegend=False),
                    go.Bar(x=[f"{model_2}"], y=[wer2], showlegend=False),
                ]
            )
            # Update the layout for better visualization
            fig.update_layout(
                title="Comparison of Two Models",
                xaxis_title="Models",
                yaxis_title="WER",
                barmode="group",
            )
            df = pd.DataFrame({"references": references, f"{model_1}": transcriptions1, "WER 1": WER1s, f"{model_2}": transcriptions2, "WER 2": WER2s})
            yield results_md, fig, df
# DATASET LOADERS
def load_Common_Voice():
    dataset = load_dataset("mozilla-foundation/common_voice_11_0", "en", revision="streaming", split="test", streaming=True, token=True, trust_remote_code=True)
    text_column = "sentence"
    dataset = dataset.take(N_SAMPLES)
    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
    dataset = list(dataset)
    for sample in dataset:
        # The Common Voice transcript lives in the "sentence" column, so lower-case that column (not "text")
        sample[text_column] = sample[text_column].lower()
    return dataset, text_column
def load_Librispeech_ASR_clean():
    dataset = load_dataset("librispeech_asr", "clean", split="test", streaming=True, token=True, trust_remote_code=True)
    print(next(iter(dataset)))
    text_column = "text"
    dataset = dataset.take(N_SAMPLES)
    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
    dataset = list(dataset)
    for sample in dataset:
        sample["text"] = sample["text"].lower()
    return dataset, text_column
def load_Librispeech_ASR_other():
    dataset = load_dataset("librispeech_asr", "other", split="test", streaming=True, token=True, trust_remote_code=True)
    print(next(iter(dataset)))
    text_column = "text"
    dataset = dataset.take(N_SAMPLES)
    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
    dataset = list(dataset)
    for sample in dataset:
        sample["text"] = sample["text"].lower()
    return dataset, text_column
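# Note on the loaders above: streaming=True together with .take(N_SAMPLES) pulls only the first N_SAMPLES
# examples instead of downloading the whole test split, and cast_column("audio", Audio(sampling_rate=16000))
# makes the datasets library resample every clip to 16 kHz on decode, which is the sampling rate all of the
# evaluated models expect.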
# MODEL LOADERS
def load_model(model_id:str):
    if model_id == "openai/whisper-tiny.en":
        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
    elif model_id == "facebook/s2t-medium-librispeech-asr":
        model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-medium-librispeech-asr")
        processor = Speech2TextProcessor.from_pretrained("facebook/s2t-medium-librispeech-asr")
    elif model_id == "facebook/wav2vec2-base-960h":
        processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
        model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
    elif model_id == "openai/whisper-large-v2":
        processor = WhisperProcessor.from_pretrained("openai/whisper-large-v2")
        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v2")
        model.config.forced_decoder_ids = None
    else:  # In case no model has been selected, Whisper-Tiny.En is used - just for completeness
        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
    return model, processor
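# Note: every model is kept on the CPU here. If the Space has a GPU, inference could be sped up by calling
# model.to("cuda") after loading and moving the processor outputs (input_features / input_values) onto the same
# device inside model_compute().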
# MODEL INFERENCE
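# The Whisper and Speech2Text checkpoints are encoder-decoder models: the processor turns the waveform into
# log-mel / filter-bank features and model.generate() decodes token ids autoregressively, which batch_decode()
# turns back into text. wav2vec2-base-960h is a CTC model instead: its forward pass yields per-frame logits,
# torch.argmax picks the most likely character per frame, and batch_decode() collapses repeats and removes
# blank tokens to form the transcription.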
def model_compute(model, processor, sample, model_id):
    if model_id == "openai/whisper-tiny.en":
        sample = sample["audio"]
        input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt").input_features
        predicted_ids = model.generate(input_features)
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
        transcription = processor.tokenizer.normalize(transcription[0])
        return transcription
    elif model_id == "facebook/s2t-medium-librispeech-asr":
        sample = sample["audio"]
        features = processor(sample["array"], sampling_rate=16000, padding=True, return_tensors="pt")
        input_features = features.input_features
        attention_mask = features.attention_mask
        gen_tokens = model.generate(input_features=input_features, attention_mask=attention_mask)
        transcription = processor.batch_decode(gen_tokens, skip_special_tokens=True)
        return transcription[0]
    elif model_id == "facebook/wav2vec2-base-960h":
        sample = sample["audio"]
        input_values = processor(sample["array"], sampling_rate=16000, return_tensors="pt", padding="longest").input_values  # Batch size 1
        logits = model(input_values).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(predicted_ids)
        return transcription[0].lower()
    elif model_id == "openai/whisper-large-v2":
        sample = sample["audio"]
        input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt").input_features
        predicted_ids = model.generate(input_features)
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
        transcription = processor.tokenizer.normalize(transcription[0])
        print("TRANSCRIPTION Whisper Large v2: ", transcription)
        return transcription
    else:  # In case no model has been selected, Whisper-Tiny.En is used - just for completeness
        sample = sample["audio"]
        input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt").input_features
        predicted_ids = model.generate(input_features)
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
        return transcription[0]
# UTILS
def compute_wer(references, predictions):
    wer = wer_metric.compute(references=references, predictions=predictions)
    return wer
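# Minimal usage sketch for running an evaluation locally without the UI (an assumption - the Space itself
# presumably wires run() to a Gradio interface, which is not part of this file). Executing it downloads the
# two chosen models and streams N_SAMPLES test examples, printing the markdown summary after each sample.
if __name__ == "__main__":
    for results_md, fig, df in run(
        data_subset="Librispeech ASR clean",
        model_1="openai/whisper-tiny.en",
        model_2="facebook/wav2vec2-base-960h",
        own_audio=None,
        own_transcription=None,
    ):
        print(results_md)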