import torch
import torchaudio
from datasets import load_dataset, load_metric
from transformers import Wav2Vec2Processor, Wav2Vec2ProcessorWithLM, AutoModelForCTC
import re
import sys
import random

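# Usage (hypothetical invocation; substitute the actual script name):
#   python <this_script>.py <use_lm: 0|1> <n_random_examples>
#   e.g. arguments "1 10" decode with the language model and print 10 random examples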
# decide via the command line whether a language model should be used for decoding
do_lm = bool(int(sys.argv[1]))
# set the number of random examples to be shown via the command line
n_elements = int(sys.argv[2])
print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n")
print("Decoding with language model\n" if do_lm else "Decoding without language model\n")
print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n")

# Empty cache
torch.cuda.empty_cache()

# set device
device = "cuda" if torch.cuda.is_available() else "cpu"

# load the test dataset (OpenSLR SLR77) from a JSON manifest
slr77_test = load_dataset("json", data_files='../xlsr-fine-tuning-gl/elra_test_manifest2.json')
print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n")
print("SLR77 test:\n")
print(slr77_test)
print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n")
print("Number of elements in SLR77 test dataset:", slr77_test["train"].num_rows, "\n")

# load metrics
# the predominant metric in ASR is the word error rate (WER);
# the character error rate (CER) is computed as well
wer = load_metric("wer")
cer = load_metric("cer")

# Chars to be removed (keep letters, accented vowels, ñ/ü, hyphens and spaces)
chars_to_remove_regex = r'[^A-Za-záéíóúñüÁÉÍÓÚÑÜ\- ]'

# load model and processor
model_path = "./"
# the LM-boosted processor bundles the tokenizer, feature extractor and an n-gram decoder
if do_lm:
    processor = Wav2Vec2ProcessorWithLM.from_pretrained(model_path, eos_token=None, bos_token=None)
else:
    processor = Wav2Vec2Processor.from_pretrained(model_path)
model = AutoModelForCTC.from_pretrained(model_path).to(device)

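# The source recordings are 48 kHz, while Wav2Vec2 expects 16 kHz input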
resampler = torchaudio.transforms.Resample(48_000, 16_000)

# Remove special characters and lowercase the transcripts
def remove_special_characters(batch):
    batch["text"] = re.sub(chars_to_remove_regex, '', batch["text"]).lower()
    return batch

# Preprocessing the datasets.
# We need to read the audio files as arrays
def prepare_dataset(batch):
    # batched output is "un-batched"
    speech_array, sampling_rate = torchaudio.load(batch["audio_filepath"])
    # resample to 16 kHz
    batch["speech"] = resampler(speech_array).squeeze().numpy()
    return batch

# Evaluation of the model.
def evaluate(batch):
    inputs = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        logits = model(inputs.input_values, attention_mask=inputs.attention_mask).logits

    if do_lm:
        # beam-search decoding with the n-gram language model
        batch["pred_strings"] = processor.batch_decode(logits.cpu().numpy()).text
    else:
        # plain greedy (argmax) CTC decoding
        pred_ids = torch.argmax(logits, dim=-1)
        batch["pred_strings"] = processor.batch_decode(pred_ids)
    
    return batch

# Show N random elements of the dataset
def show_random_elements(dataset, num_examples):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    # sample distinct random row indices
    picks = random.sample(range(len(dataset)), num_examples)

    # Print headings
    print(f"\n{'Row':<4}{'File':<28}{'Sentence':<105}{'Prediction':<105}\n")
    # Print data
    for i in range(num_examples):
        row = picks[i]
        path = dataset[row]["audio_filepath"][-25:]
        reference = dataset[row]["text"]
        prediction = dataset[row]["pred_strings"]
        print(f"{row:<4}{path:<28}{reference:<105}{prediction:<105}")


# Remove special characters and lowercase the transcripts
test_dataset = slr77_test.map(remove_special_characters)

# Prepare dataset
test_dataset = test_dataset.map(prepare_dataset)

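# Run inference over the whole test set in batches of 8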
result = test_dataset.map(evaluate, batched=True, batch_size=8)

print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n")
print(f"Showing {n_elements} random elementes:\n")
show_random_elements(result["train"], n_elements)


print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
print("WER: {:2f}".format(100 * wer.compute(references=result["train"]["text"], predictions=result["train"]["pred_strings"])))
print("CER: {:2f}".format(100 * cer.compute(references=result["train"]["text"], predictions=result["train"]["pred_strings"])))
print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")