# whisper-small-ar / evaluate_models.py
# Hugging Face Hub user: malmarz
# Snapshot from commit 8b8ff67 ("Training in progress, step 1000").
import torch
import librosa
from datasets import load_dataset, Audio
from transformers import WhisperProcessor, WhisperFeatureExtractor, WhisperTokenizer, WhisperForConditionalGeneration
from huggingface_hub import login
import argparse
from evaluate import load
# Command-line options for one evaluation run: which Whisper checkpoint to
# score, and on which dataset / configuration / split.
my_parser = argparse.ArgumentParser(
    description="Compute the word error rate (WER) of a Whisper checkpoint "
    "on an Arabic speech dataset split."
)
my_parser.add_argument(
    "--model_name", "-model_name", type=str, action="store",
    default="openai/whisper-tiny",
    help="Hub id or local path of the Whisper checkpoint (default: %(default)s).",
)
my_parser.add_argument(
    "--hf_token", "-hf_token", type=str, action="store",
    help="Hugging Face access token passed to huggingface_hub.login().",
)
my_parser.add_argument(
    "--dataset_name", "-dataset_name", type=str, action="store",
    default="google/fleurs",
    help="Hub id of the evaluation dataset (default: %(default)s).",
)
my_parser.add_argument(
    "--split", "-split", type=str, action="store",
    default="test",
    help="Dataset split to evaluate on (default: %(default)s).",
)
my_parser.add_argument(
    "--subset", "-subset", type=str, action="store",
    help="Dataset configuration name, passed as the second argument of load_dataset().",
)
args = my_parser.parse_args()
# Authenticate with the Hugging Face Hub; load_dataset() below relies on the
# stored credentials via use_auth_token=True.
try:
    login(token=args.hf_token)
except Exception as err:
    # Raise a real exception object (a bare string is not raisable in
    # Python 3), keep the original cause chained, and never echo the secret
    # token itself into the terminal or logs.
    raise RuntimeError(
        "Can't log in to the Hugging Face Hub; please pass a valid --hf_token"
    ) from err
# Resolve the run configuration from the parsed CLI arguments.
model_name = args.model_name
dataset_name = args.dataset_name
subset = args.subset
# FLEURS keeps its reference text under "transcription"; any other dataset is
# assumed to use a "sentence" column.
text_column = "transcription" if dataset_name == "google/fleurs" else "sentence"
print(f"Evaluating {model_name} on {dataset_name} [{subset}]")

# Load the checkpoint and its pre/post-processing components.
feature_extractor = WhisperFeatureExtractor.from_pretrained(model_name)
model = WhisperForConditionalGeneration.from_pretrained(model_name)
# NOTE(review): newer `datasets` releases deprecate `use_auth_token` in favour
# of `token`; confirm against the installed version.
test_dataset = load_dataset(dataset_name, subset, split=args.split, use_auth_token=True)
processor = WhisperProcessor.from_pretrained(model_name, language="Arabic", task="transcribe")
tokenizer = WhisperTokenizer.from_pretrained(model_name, language="Arabic", task="transcribe")
# Have the audio column decoded/resampled to 16 kHz whenever it is accessed.
test_dataset = test_dataset.cast_column("audio", Audio(sampling_rate=16000))
# Turn each raw example into Whisper model inputs plus tokenized targets.
def prepare_dataset(batch):
    """Add ``input_features`` and ``labels`` to a single dataset example.

    Reads the decoded ``audio`` entry and the reference text stored under the
    module-level ``text_column``. It writes two fields into ``batch`` in place
    and returns it, as ``Dataset.map`` expects:

    * ``input_features``: log-Mel features from ``feature_extractor``.
    * ``labels``: token ids from ``tokenizer``.
    """
    clip = batch["audio"]
    # No resampling happens here: cast_column(...) already set the decode
    # rate, so the clip's own sampling_rate is forwarded unchanged.
    extracted = feature_extractor(clip["array"], sampling_rate=clip["sampling_rate"])
    batch["input_features"] = extracted.input_features[0]
    reference = batch[text_column]
    batch["labels"] = tokenizer(reference).input_ids
    return batch


test_dataset = test_dataset.map(function=prepare_dataset)
# Prefer the GPU, but fall back to the CPU so the script still runs (slowly)
# on machines without CUDA instead of crashing.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running inference on {device}")
model = model.to(device)
# Pin the decoder prompt (language/task tokens) to Arabic transcription.
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ar", task="transcribe")


def map_to_result(batch):
    """Transcribe one preprocessed example and decode its reference.

    Reads ``input_features`` and ``labels`` (as produced by ``prepare_dataset``).
    It adds two fields to ``batch`` in place and returns it for ``Dataset.map``:

    * ``pred_str``: the decoded ``model.generate()`` output.
    * ``text``: the decoded reference labels.

    Special tokens are stripped from both.
    """
    with torch.no_grad():
        # Build a batch of one on whatever device the model lives on, so the
        # inputs always match the model placement chosen above.
        input_values = torch.tensor(batch["input_features"], device=model.device).unsqueeze(0)
        pred_ids = model.generate(input_values)
        batch["pred_str"] = processor.batch_decode(pred_ids, skip_special_tokens=True)[0]
        batch["text"] = processor.decode(batch["labels"], skip_special_tokens=True)
    return batch
# Transcribe the whole split, then score predictions against the references.
results = test_dataset.map(map_to_result)
wer = load("wer")
test_wer = wer.compute(predictions=results["pred_str"], references=results["text"])
print(f"Test WER: {test_wer:.3f}")