"""Evaluate a fine-tuned wav2vec2 CTC model on a JSON speech test manifest.

Usage: python <script> DO_LM N_ELEMENTS

  DO_LM       1 to decode with Wav2Vec2ProcessorWithLM (language-model
              decoding), 0 for greedy (argmax) CTC decoding.
  N_ELEMENTS  number of random test rows to print next to their predictions.

The model and processor are loaded from the current directory. WER and CER
over the whole test split are printed as percentages.
"""
import random
import re
import sys

import numpy
import pandas as pd
import torch
import torchaudio
from datasets import load_dataset, load_metric, Audio
from transformers import (
    AutoModelForCTC,
    Wav2Vec2CTCTokenizer,
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
    Wav2Vec2ProcessorWithLM,
)

# --- Command line -----------------------------------------------------------
_USAGE = f"usage: {sys.argv[0]} DO_LM N_ELEMENTS  (DO_LM: 0 or 1, N_ELEMENTS: integer)"
try:
    # Whether to decode with the language model (any non-zero integer enables it).
    do_lm = bool(int(sys.argv[1]))
    # Number of random examples to print alongside their predictions.
    n_elements = int(sys.argv[2])
except (IndexError, ValueError):
    # Missing or non-integer arguments: report usage instead of a traceback.
    sys.exit(_USAGE)

print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n")
print("Decoding with language model\n" if do_lm else "Decoding without language model\n")
print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n")

# Release unused memory held by PyTorch's CUDA caching allocator.
torch.cuda.empty_cache()

# Run on the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# --- Data -------------------------------------------------------------------
# The manifest is loaded with the generic "json" builder; all rows end up in
# the "train" split, which is what the rest of the script reads. Each row is
# expected to carry "audio_filepath" and "text" fields.
slr77_test = load_dataset("json", data_files='../xlsr-fine-tuning-gl/elra_test_manifest2.json')
print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n")
print("SLR77 test:\n")
print(slr77_test)
print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n")
print("Number of elements in SLR77 test dataset:", slr77_test["train"].num_rows, "\n")

# --- Metrics ----------------------------------------------------------------
# Word error rate (WER) is the predominant ASR metric; character error rate
# (CER) is reported alongside it.
# NOTE(review): load_metric is deprecated in recent `datasets` releases —
# confirm the pinned version still provides it.
wer = load_metric("wer")
cer = load_metric("cer")

# Characters NOT in this class are stripped from the reference text: only
# ASCII letters, the accented vowels, ñ/ü (both cases), hyphen and space
# survive. Raw string so "\-" is passed to the regex engine without an
# invalid-escape warning.
chars_to_remove_regex = r'[^A-Za-záéíóúñüÁÉÍÓÚÑÜ\- ]'

# --- Model ------------------------------------------------------------------
model_path = "./"
if do_lm:
    # eos/bos disabled — presumably so LM decoding does not add sentence
    # markers; verify against how the processor was saved.
    processor = Wav2Vec2ProcessorWithLM.from_pretrained(model_path, eos_token=None, bos_token=None)
else:
    processor = Wav2Vec2Processor.from_pretrained(model_path)
model = AutoModelForCTC.from_pretrained(model_path).to(device)

# Converts 48 kHz audio to the 16 kHz rate passed to the processor.
resampler = torchaudio.transforms.Resample(48_000, 16_000)


def remove_special_characters(batch):
    """Normalize ``batch["text"]``: strip disallowed characters, then lowercase.

    Intended for an unbatched ``Dataset.map``: ``batch`` is a single row dict,
    which is updated in place and returned.
    """
    batch["text"] = re.sub(chars_to_remove_regex, '', batch["text"]).lower()
    return batch
# --- Preprocessing ----------------------------------------------------------
# Sampling rate of the audio handed to the processor/model.
TARGET_SAMPLING_RATE = 16_000


def prepare_dataset(batch):
    """Read a row's audio file into ``batch["speech"]`` as a mono 16 kHz array.

    Intended for an unbatched ``Dataset.map``. The audio is resampled from the
    rate reported by ``torchaudio.load`` (not an assumed one), and
    multi-channel audio is downmixed by averaging its channels.
    """
    speech_array, sampling_rate = torchaudio.load(batch["audio_filepath"])
    # torchaudio.load returns (channels, frames); collapse to one channel.
    speech = speech_array.mean(dim=0)
    if sampling_rate == resampler.orig_freq:
        # Reuse the prebuilt resampler for the common 48 kHz case.
        speech = resampler(speech)
    elif sampling_rate != TARGET_SAMPLING_RATE:
        speech = torchaudio.functional.resample(speech, sampling_rate, TARGET_SAMPLING_RATE)
    batch["speech"] = speech.numpy()
    return batch


# --- Inference --------------------------------------------------------------
def evaluate(batch):
    """Transcribe a batch of rows and store the hypotheses in ``batch["pred_strings"]``.

    Intended for ``Dataset.map(batched=True)``: ``batch["speech"]`` holds one
    waveform per row. With ``do_lm`` the logits are decoded by
    Wav2Vec2ProcessorWithLM; otherwise greedy argmax CTC decoding is used.
    """
    inputs = processor(
        batch["speech"], sampling_rate=TARGET_SAMPLING_RATE, return_tensors="pt", padding=True
    ).to(device)
    with torch.no_grad():
        # Some feature extractors do not return an attention mask; pass None
        # in that case instead of failing on a missing attribute.
        logits = model(inputs.input_values, attention_mask=inputs.get("attention_mask")).logits
    if do_lm:
        # The LM decoder is given NumPy logits on the CPU.
        batch["pred_strings"] = processor.batch_decode(logits.cpu().numpy()).text
    else:
        pred_ids = torch.argmax(logits, dim=-1)
        batch["pred_strings"] = processor.batch_decode(pred_ids)
    return batch


# --- Reporting --------------------------------------------------------------
def show_random_elements(dataset, num_examples):
    """Print ``num_examples`` distinct random rows as a Row/File/Sentence/Prediction table.

    Raises:
        ValueError: if ``num_examples`` is negative or larger than ``len(dataset)``.
    """
    if num_examples < 0:
        raise ValueError("num_examples must be non-negative.")
    if num_examples > len(dataset):
        raise ValueError("Can't pick more elements than there are in the dataset.")
    # Distinct indices without replacement.
    picks = random.sample(range(len(dataset)), num_examples)

    print(f"\n{'Row':<4}{'File':<28}{'Sentence':<105}{'Prediction':<105}\n")
    for row in picks:
        example = dataset[row]  # fetch the row once rather than once per column
        path = example["audio_filepath"][-25:]  # keep only the tail of the path
        print(f"{row:<4}{path:<28}{example['text']:<105}{example['pred_strings']:<105}")


# Normalize references, load/resample audio, then transcribe in batches of 8.
test_dataset = slr77_test.map(remove_special_characters)
test_dataset = test_dataset.map(prepare_dataset)
result = test_dataset.map(evaluate, batched=True, batch_size=8)

print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n")
print(f"Showing {n_elements} random elements:\n")
show_random_elements(result["train"], n_elements)
# --- Scores -----------------------------------------------------------------
# Score the full test split; each metric is scaled by 100 and printed as a
# percentage with two decimal places.
references = result["train"]["text"]
predictions = result["train"]["pred_strings"]
print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
print("WER: {:.2f}".format(100 * wer.compute(references=references, predictions=predictions)))
print("CER: {:.2f}".format(100 * cer.compute(references=references, predictions=predictions)))
print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")