File size: 4,870 Bytes
740d8bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import torch
import torchaudio
from datasets import load_dataset, load_metric, Audio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ForCTC, Wav2Vec2ForCTC, AutoModelForCTC, Wav2Vec2ProcessorWithLM, Wav2Vec2CTCTokenizer
import numpy
import re
import sys
import random


# Command-line interface: two positional arguments are required.
#   argv[1]: integer flag, non-zero to decode with the language model
#   argv[2]: number of random examples to print after evaluation
# Fail with a usage message rather than a bare IndexError when they are missing.
if len(sys.argv) < 3:
    sys.exit(f"Usage: {sys.argv[0]} <do_lm: 0|1> <n_elements>")

# decide if lm should be used for decoding or not via command line
do_lm = bool(int(sys.argv[1]))
# set the number of random examples to be shown via command line
n_elements = int(sys.argv[2])
#eval_size = int(sys.argv[3])
print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n")
if do_lm:
    print("Decoding with language model\n")
else:
    print("Decoding without language model\n")
print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n")

# Empty cache
torch.cuda.empty_cache()

# set device: use the GPU when one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# load dataset (Galician "gl" configuration, test split)
common_voice_test = load_dataset("mozilla-foundation/common_voice_7_0", "gl", split="test")
#common_voice_test = load_dataset("mozilla-foundation/common_voice_7_0", "gl", split="test[:1%]")
print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n")
print("Common Voice test dataset:\n")
print(common_voice_test)
print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n")
print("Number of elements in Common Voice test dataset:", common_voice_test.num_rows, "\n")

# load metrics
# the predominant metric in ASR is the word error rate (WER);
# the character error rate (CER) is reported alongside it
# NOTE(review): load_metric is reported as deprecated in recent `datasets`
# releases - confirm against the pinned version before upgrading.
wer = load_metric("wer")
cer = load_metric("cer")

# Chars to be removed: the class is negated, so every character that is NOT an
# ASCII letter, one of the listed accented letters, a hyphen or a space matches.
# Raw string so that `\-` reaches the regex engine without Python treating it
# as an (invalid) string escape sequence.
chars_to_remove_regex = r'[^A-Za-záéíóúñüÁÉÍÓÚÑÜ\- ]'
#chars_to_remove_regex = '[\,\¿\?\.\¡\!\;\:\"\n\t()\{\}\[\]]'

# load model and processor from the current directory;
# the LM-aware processor is only needed when decoding with the language model
model_path = "./"
if do_lm:
    processor = Wav2Vec2ProcessorWithLM.from_pretrained(model_path, eos_token=None, bos_token=None)
else:
    processor = Wav2Vec2Processor.from_pretrained(model_path)
model = AutoModelForCTC.from_pretrained(model_path).to(device)

# Remove special characters and apply lowercase normalization
def remove_special_characters(batch):
    """Normalize the reference transcription of a single example.

    Every substring matched by ``chars_to_remove_regex`` is deleted from
    ``batch["sentence"]`` and the remaining text is lowercased. The example
    is updated in place and also returned.
    """
    stripped = re.sub(chars_to_remove_regex, '', batch["sentence"])
    batch["sentence"] = stripped.lower()
    return batch

# Preprocessing the dataset
def prepare_dataset(batch):
    """Add model inputs and label ids to a single example.

    Writes three keys into ``batch`` and returns it:
      * ``input_values``: first entry of the processor output for the audio
      * ``input_length``: length of ``input_values``
      * ``labels``: ids produced for ``batch["sentence"]`` in target mode
    """
    clip = batch["audio"]
    features = processor(clip["array"], sampling_rate=clip["sampling_rate"])
    # The processor output is batched; take its single entry.
    batch["input_values"] = features.input_values[0]
    batch["input_length"] = len(batch["input_values"])

    # Inside this context the processor handles the text side.
    with processor.as_target_processor():
        encoded = processor(batch["sentence"])
    batch["labels"] = encoded.input_ids
    return batch

# Evaluation of the model
def evaluate(batch):
    """Transcribe a batch of examples and store the text in ``pred_strings``.

    ``batch["input_values"]` is padded and converted to PyTorch tensors by the
    processor, moved to ``device`` and run through ``model`` without gradient
    tracking. When ``do_lm`` is set, the logits are decoded with the
    processor's LM-aware ``batch_decode`` (its ``.text`` field is kept);
    otherwise the arg-max token ids are decoded directly.
    """
    features = processor(
        batch["input_values"],
        sampling_rate=16_000,
        return_tensors="pt",
        padding=True,
    ).to(device)

    # Inference only, so no gradients are needed.
    with torch.no_grad():
        logits = model(features.input_values, attention_mask=features.attention_mask).logits

    if not do_lm:
        # Greedy decoding: most likely token at every frame.
        best_ids = torch.argmax(logits, dim=-1)
        batch["pred_strings"] = processor.batch_decode(best_ids)
        return batch

    # The LM decoder is fed the logits as a NumPy array on the CPU.
    batch["pred_strings"] = processor.batch_decode(logits.cpu().numpy()).text
    return batch

# Show N random elements of the dataset
def show_random_elements(dataset, num_examples):
    """Print a table of randomly chosen, distinct rows of ``dataset``.

    Each printed line holds a running index, the last 12 characters of the
    row's ``path``, its ``up_votes`` and ``down_votes``, the reference
    ``sentence`` and the prediction stored under ``pred_strings``.

    Args:
        dataset: Collection supporting ``len()`` and integer indexing whose
            rows are mappings containing the keys listed above.
        num_examples: Number of distinct rows to print. A value below zero
            prints only the heading.

    Raises:
        AssertionError: If ``num_examples`` exceeds ``len(dataset)``.
    """
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."

    # Draw distinct row indices in a single call. Unlike a retry-until-unique
    # loop, this cannot spin forever if the assertion above is stripped by
    # `python -O`: random.sample raises ValueError for an oversized request.
    picks = random.sample(range(len(dataset)), max(num_examples, 0))

    # Print headings
    print(f"\n{'Id':<4}{'File':<14}{'P':<3}{'N':<3}{'Sentence':<95}{'Prediction':<95}\n")
    # Print data
    for i, pick in enumerate(picks):
        # Look the row up once and reuse it for every column.
        row = dataset[pick]
        path = row["path"][-12:]
        up_votes = row["up_votes"]
        down_votes = row["down_votes"]
        reference = row["sentence"]
        prediction = row["pred_strings"]
        print(f"{i:<4}{path:<14}{up_votes:<3}{down_votes:<3}{reference:<95}{prediction:<95}")

# Remove special characters and lowercase the reference sentences
test_dataset = common_voice_test.map(remove_special_characters)

# resampling to 16KHz
test_dataset = test_dataset.cast_column("audio", Audio(sampling_rate=16_000))

# Prepare dataset
test_dataset = test_dataset.map(prepare_dataset)

# Evaluate dataset in batches of 8 examples
result = test_dataset.map(evaluate, batched=True, batch_size=8)

print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n")
print(f"Showing {n_elements} random elementes:\n")
show_random_elements(result, n_elements)


# Extract each column once; both metrics score the same reference/prediction pairs.
references = result["sentence"]
predictions = result["pred_strings"]

# Report both error rates as percentages with two decimals.
# (`{:2f}` only set a minimum width of 2 and printed six decimals.)
print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
print("WER: {:.2f}".format(100 * wer.compute(references=references, predictions=predictions)))
print("CER: {:.2f}".format(100 * cer.compute(references=references, predictions=predictions)))
print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")