# cmagui's picture
# First model version (wiki-gl LM)
# 740d8bb
import torch
import torchaudio
from datasets import load_dataset, load_metric, Audio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ForCTC, Wav2Vec2ForCTC, AutoModelForCTC, Wav2Vec2ProcessorWithLM, Wav2Vec2CTCTokenizer
import numpy
import re
import sys
import random
# decide if lm should be used for decoding or not via command line
do_lm = bool(int(sys.argv[1]))
# set the number of random examples to be shown via command line
n_elements = int(sys.argv[2])
#eval_size = int(sys.argv[3])
print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n")
print("Decoding with language model\n") if do_lm else print("Decoding without language model\n")
print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n")
# Empty cache
torch.cuda.empty_cache()
# set devide
device = "cuda" if torch.cuda.is_available() else "cpu"
# load dataset
common_voice_test = load_dataset("mozilla-foundation/common_voice_7_0", "gl", split="test")
#common_voice_test = load_dataset("mozilla-foundation/common_voice_7_0", "gl", split="test[:1%]")
print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n")
print("Common Voice test dataset:\n")
print(common_voice_test)
print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n")
print("Number of elements in Common Voice test dataset:", common_voice_test.num_rows, "\n")
# load metric
# the predominant metric in ASR is the word error rate (WER)
wer = load_metric("wer")
cer = load_metric("cer")
# Chars to be removed
chars_to_remove_regex = '[^A-Za-záéíóúñüÁÉÍÓÚÑÜ\- ]'
#chars_to_remove_regex = '[\,\¿\?\.\¡\!\;\:\"\n\t()\{\}\[\]]'
# load model and processor
model_path = "./"
processor = Wav2Vec2ProcessorWithLM.from_pretrained(model_path, eos_token=None, bos_token=None) if do_lm else Wav2Vec2Processor.from_pretrained(model_path)
model = AutoModelForCTC.from_pretrained(model_path).to(device)
# Remove special characters and lowcase normalization
def remove_special_characters(batch):
batch["sentence"] = re.sub(chars_to_remove_regex, '', batch["sentence"]).lower()
return batch
# Preprocessing the dataset
def prepare_dataset(batch):
# batched output is "un-batched"
audio = batch["audio"]
batch["input_values"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
batch["input_length"] = len(batch["input_values"])
with processor.as_target_processor():
batch["labels"] = processor(batch["sentence"]).input_ids
return batch
# Evaluation of the model
def evaluate(batch):
inputs = processor(batch["input_values"], sampling_rate=16_000, return_tensors="pt", padding=True).to(device)
with torch.no_grad():
#logits = model(inputs.input_values.to(device), attention_mask=inputs.attention_mask.to(device)).logits
logits = model(inputs.input_values, attention_mask=inputs.attention_mask).logits
if do_lm:
# batch["pred_strings"] = processor.batch_decode(logits.detach().numpy()).text
batch["pred_strings"] = processor.batch_decode(logits.cpu().numpy()).text
else:
pred_ids = torch.argmax(logits, dim=-1)
batch["pred_strings"] = processor.batch_decode(pred_ids)
return batch
# Show N random elements of the dataset
def show_random_elements(dataset, num_examples):
assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
picks = []
for _ in range(num_examples):
pick = random.randint(0, len(dataset)-1)
while pick in picks:
pick = random.randint(0, len(dataset)-1)
picks.append(pick)
# Print headings
print(f"\n{'Id':<4}{'File':<14}{'P':<3}{'N':<3}{'Sentence':<95}{'Prediction':<95}\n")
# Pring data
for i in range(0,num_examples):
row = picks[i]
path = dataset[row]["path"][-12:]
up_votes = dataset[row]["up_votes"]
down_votes = dataset[row]["down_votes"]
reference = dataset[row]["sentence"]
prediction = dataset[row]["pred_strings"]
print(f"{i:<4}{path:<14}{up_votes:<3}{down_votes:<3}{reference:<95}{prediction:<95}")
# Remove special characters and loowcase normalization
test_dataset = common_voice_test.map(remove_special_characters)
# resampling to 16KHz
test_dataset = test_dataset.cast_column("audio", Audio(sampling_rate=16_000))
# Prepare dataset
test_dataset = test_dataset.map(prepare_dataset)
# Evaluate dataset
result = test_dataset.map(evaluate, batched=True, batch_size=8)
print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n")
print(f"Showing {n_elements} random elementes:\n")
show_random_elements(result, n_elements)
print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
print("WER: {:2f}".format(100 * wer.compute(references=result["sentence"], predictions=result["pred_strings"])))
print("CER: {:2f}".format(100 * cer.compute(references=result["sentence"], predictions=result["pred_strings"])))
print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")