|
import torch |
|
import torchaudio |
|
from datasets import load_dataset, load_metric, Audio |
|
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ForCTC, Wav2Vec2ForCTC, AutoModelForCTC, Wav2Vec2ProcessorWithLM, Wav2Vec2CTCTokenizer |
|
import numpy |
|
import re |
|
import sys |
|
import random |
|
import pandas as pd |
|
|
|
|
|
# Command-line interface (positional arguments):
#   sys.argv[1]  integer flag; 0 = greedy (argmax) CTC decoding, any other
#                integer = decode through the language-model processor
#   sys.argv[2]  number of random test examples to print with their predictions
# Both are validated up front so a bad invocation fails immediately with a
# usage message (instead of an IndexError/ValueError traceback), before any
# dataset or model loading starts.
_USAGE = f"usage: {sys.argv[0]} <use_lm: 0|1> <n_examples>"

if len(sys.argv) < 3:
    sys.exit(_USAGE)

try:
    do_lm = bool(int(sys.argv[1]))
    n_elements = int(sys.argv[2])
except ValueError:
    sys.exit(f"{_USAGE}\nerror: both arguments must be integers")

if n_elements < 0:
    sys.exit(f"{_USAGE}\nerror: <n_examples> must be non-negative")

print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n")
# One print with a conditional argument, rather than a conditional
# expression evaluated only for its side effect.
print("Decoding with language model\n" if do_lm else "Decoding without language model\n")
print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n")
|
|
|
|
|
# Ask PyTorch's CUDA caching allocator to release any unoccupied cached memory
# before the model is loaded.
torch.cuda.empty_cache()

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
|
|
|
|
|
|
|
# Load the test manifest as a JSON dataset. Each record is read later for its
# "audio_filepath" and "text" fields. With a bare `data_files` argument,
# `load_dataset` puts every record under a single "train" split, which is why
# the rest of the script indexes slr77_test["train"].
slr77_test = load_dataset(
    "json",
    data_files='../xlsr-fine-tuning-gl/elra_test_manifest2.json',
)

print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n")
print("SLR77 test:\n", slr77_test, sep="\n")
print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n")
print(f"Number of elements in SLR77 test dataset: {slr77_test['train'].num_rows} \n")
|
|
|
|
|
|
|
# Word error rate and character error rate, used for the final report.
# NOTE(review): `datasets.load_metric` is deprecated in later `datasets`
# releases in favour of the separate `evaluate` library; confirm the pinned
# `datasets` version still provides it.
wer, cer = load_metric("wer"), load_metric("cer")
|
|
|
|
|
# Character class of everything to DROP from reference transcripts: any
# character that is not an ASCII letter, one of the listed accented letters,
# a hyphen or a plain space. It is written as a raw string so `\-` is not
# parsed as an (invalid) string escape, which Python warns about
# (DeprecationWarning, or SyntaxWarning on 3.12+). The pattern value is
# unchanged.
chars_to_remove_regex = r'[^A-Za-záéíóúñüÁÉÍÓÚÑÜ\- ]'
|
|
|
|
|
|
|
# The model, its processor and (when requested) the LM decoder files are all
# read from the current working directory.
model_path = "./"

if do_lm:
    # NOTE(review): eos/bos are explicitly disabled here, presumably so the
    # tokenizer does not define <s>/</s> tokens unknown to the LM decoder's
    # alphabet; confirm against how the LM was built.
    processor = Wav2Vec2ProcessorWithLM.from_pretrained(model_path, eos_token=None, bos_token=None)
else:
    processor = Wav2Vec2Processor.from_pretrained(model_path)

model = AutoModelForCTC.from_pretrained(model_path).to(device)

# Fixed-rate resampler from 48 kHz to the 16 kHz rate the processor is called
# with in evaluate(). Assumes the test audio is recorded at 48 kHz; TODO
# confirm for this manifest.
resampler = torchaudio.transforms.Resample(48_000, 16_000)
|
|
|
|
|
def remove_special_characters(batch):
    """Normalize the reference transcript stored in ``batch["text"]``.

    Every whitespace run (tabs, newlines, non-breaking spaces, ...) first
    becomes a single space. Without this, the character filter would delete
    that whitespace and glue the neighbouring words together. Characters
    matched by ``chars_to_remove_regex`` are then removed. Finally, the double
    spaces left behind by removed punctuation are collapsed, and the result is
    stripped and lower-cased.

    Args:
        batch: A single dataset example (mapping) with a string ``"text"``
            field.

    Returns:
        The same mapping, with ``"text"`` replaced by the normalized string.
    """
    text = re.sub(r"\s+", " ", batch["text"])
    text = re.sub(chars_to_remove_regex, '', text)
    batch["text"] = re.sub(r" {2,}", " ", text).strip().lower()
    return batch
|
|
|
|
|
|
|
def prepare_dataset(batch):
    """Load one example's audio as a mono waveform resampled to 16 kHz.

    Args:
        batch: A single dataset example with an ``"audio_filepath"`` field.

    Returns:
        The same mapping with a new ``"speech"`` entry holding a 1-D NumPy
        array at 16 kHz (the rate ``evaluate`` passes to the processor).
    """
    speech_array, sampling_rate = torchaudio.load(batch["audio_filepath"])

    # torchaudio.load returns (channels, frames). Averaging over channels gives
    # a 1-D signal even for multi-channel files; for mono input it is exactly
    # the single channel.
    speech_array = speech_array.mean(dim=0)

    # The module-level resampler is built for one fixed input rate. Use it only
    # when the file really has that rate. Otherwise resample from the file's
    # actual rate (or not at all if it is already 16 kHz), instead of silently
    # time-stretching the audio.
    if sampling_rate == resampler.orig_freq:
        speech_array = resampler(speech_array)
    elif sampling_rate != 16_000:
        speech_array = torchaudio.functional.resample(speech_array, sampling_rate, 16_000)

    batch["speech"] = speech_array.numpy()
    return batch
|
|
|
|
|
def evaluate(batch):
    """Transcribe a batch of examples and store the hypotheses.

    Args:
        batch: A batch from ``Dataset.map(..., batched=True)`` whose
            ``"speech"`` column is a list of 1-D waveforms, expected at 16 kHz
            (the rate passed to the processor below).

    Returns:
        The same batch with a ``"pred_strings"`` column of decoded transcripts.
    """
    inputs = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True).to(device)

    with torch.no_grad():
        # A feature extractor configured without attention masks returns no
        # "attention_mask" key at all. .get() passes None in that case instead
        # of raising AttributeError.
        logits = model(inputs.input_values, attention_mask=inputs.get("attention_mask")).logits

    if do_lm:
        # The LM-backed processor decodes from the raw logits as a CPU NumPy
        # array; its result exposes the transcripts via .text.
        batch["pred_strings"] = processor.batch_decode(logits.cpu().numpy()).text
    else:
        # Greedy CTC decoding: most likely token id per frame.
        pred_ids = torch.argmax(logits, dim=-1)
        batch["pred_strings"] = processor.batch_decode(pred_ids)

    return batch
|
|
|
|
|
def show_random_elements(dataset, num_examples):
    """Print ``num_examples`` distinct, randomly chosen rows as a table.

    Each printed row shows the row index, the last 25 characters of the audio
    path, the reference transcript and the model prediction.

    Args:
        dataset: Indexable collection of examples (e.g. a ``datasets.Dataset``)
            whose items have ``"audio_filepath"``, ``"text"`` and
            ``"pred_strings"`` fields.
        num_examples: How many distinct rows to show.

    Raises:
        ValueError: If ``num_examples`` exceeds the number of rows, or is
            negative (raised by ``random.sample``).
    """
    if num_examples > len(dataset):
        # An explicit raise rather than an assert, so the check is not
        # stripped when running under `python -O`.
        raise ValueError("Can't pick more elements than there are in the dataset.")

    # Distinct row indices, drawn without replacement in one call.
    picks = random.sample(range(len(dataset)), num_examples)

    print(f"\n{'Row':<4}{'File':<28}{'Sentence':<105}{'Prediction':<105}\n")

    for row in picks:
        # Fetch the example once. Indexing a datasets.Dataset returns every
        # column of the row (including the audio array), so separate
        # dataset[row][...] lookups per field would repeat that work.
        example = dataset[row]
        path = example["audio_filepath"][-25:]
        reference = example["text"]
        prediction = example["pred_strings"]
        print(f"{row:<4}{path:<28}{reference:<105}{prediction:<105}")
|
|
|
|
|
|
|
# Normalize the reference transcripts, load and resample the audio, then run
# the model over the test split in batches of 8 examples.
test_dataset = slr77_test.map(remove_special_characters)
test_dataset = test_dataset.map(prepare_dataset)
result = test_dataset.map(evaluate, batched=True, batch_size=8)

print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n")
print(f"Showing {n_elements} random elements:\n")
show_random_elements(result["train"], n_elements)

# Read each column once and reuse it for both metrics.
references = result["train"]["text"]
predictions = result["train"]["pred_strings"]

print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
# Error rates as percentages with two decimals. `.2f` is a precision;
# `2f` without the dot would only be a minimum field width.
print("WER: {:.2f}".format(100 * wer.compute(references=references, predictions=predictions)))
print("CER: {:.2f}".format(100 * cer.compute(references=references, predictions=predictions)))
print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
|
|