File size: 5,162 Bytes
133f88a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f551c27
133f88a
 
 
 
d2617b7
2cf743a
133f88a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b878951
f551c27
133f88a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f551c27
133f88a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import torch
import torchaudio
from datasets import load_dataset, load_metric
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, set_seed
import argparse
from pyctcdecode import build_ctcdecoder
from multiprocessing import Pool
import re

"""
This is the script to run the prediction on test set of Turkish speech dataset.
Usage:
python run_evaluation.y -m <wav2vec2 model_name> -d <Zindi dataset directory> -o <output file name> \
    -b <optional batch size, default=16>
"""


class KenLM:
    """CTC beam-search decoder backed by a KenLM language model (via pyctcdecode).

    Wraps ``pyctcdecode.build_ctcdecoder`` with the tokenizer's vocabulary and
    decodes batches of logits in parallel with a multiprocessing pool.
    """

    def __init__(self, tokenizer, model_name, unigrams=None, num_workers=8, beam_width=128):
        """
        Args:
            tokenizer: Wav2Vec2 tokenizer providing the CTC vocabulary.
            model_name: path to the KenLM binary/ARPA language model.
            unigrams: optional unigram vocabulary — either a collection of
                strings, or a path to a text file with one unigram per line.
            num_workers: size of the pool used by decode().
            beam_width: beam width used during decoding.
        """
        self.num_workers = num_workers
        self.beam_width = beam_width
        vocab_dict = tokenizer.get_vocab()
        self.vocabulary = [tok for tok, _ in sorted(vocab_dict.items(), key=lambda kv: kv[1])]
        # Drop the final vocab entry — presumably a trailing special token that
        # must not appear in the decoder alphabet; TODO confirm against tokenizer.
        self.vocabulary = self.vocabulary[:-1]
        # BUG FIX: `unigrams` was accepted but silently ignored. The CLI passes a
        # file path, so load it into a list when given a string.
        if isinstance(unigrams, str):
            with open(unigrams, encoding="utf-8") as f:
                unigrams = [line.strip() for line in f if line.strip()]
        self.decoder = build_ctcdecoder(self.vocabulary, model_name, unigrams=unigrams)

    @staticmethod
    def lm_postprocess(text):
        """Drop single-character tokens (likely CTC/LM noise) from decoded text."""
        # BUG FIX: the old code mapped short tokens to "" and joined, which left
        # double spaces in the output; filtering avoids that.
        return ' '.join(tok for tok in text.split() if len(tok) > 1)

    def decode(self, logits):
        """Decode a (batch, time, vocab) tensor of logits into cleaned strings."""
        probs = logits.cpu().numpy()
        with Pool(self.num_workers) as pool:
            # BUG FIX: beam_width was stored but never forwarded to the decoder.
            texts = self.decoder.decode_batch(pool, probs, beam_width=self.beam_width)
        return [KenLM.lm_postprocess(t) for t in texts]

chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"\“\‘\”\'\`…\’»«]'

def main():
    """Evaluate a wav2vec2 CTC model on a dataset's test split and print WER.

    Loads the model/processor, optionally wraps decoding with a KenLM
    beam-search decoder, resamples audio to 16 kHz, normalizes reference
    text, runs batched inference, and prints the word error rate.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model_name", type=str, required=True,
                        help="The wav2vec2 model name")
    parser.add_argument("-n", "--name", type=str, required=True,
                        help="The name of dataset")
    parser.add_argument("-c", "--config_name", type=str, required=True,
                        help="The config name of the dataset")
    parser.add_argument("-d", "--data_dir", type=str, required=False, default=None,
                        help="The directory contains the dataset")
    parser.add_argument("-b", "--batch_size", type=int, required=False, default=16,
                        help="Batch size")
    parser.add_argument("-k", "--kenlm", type=str, required=False, default=None,
                        help="Path to KenLM model")
    parser.add_argument("-u", "--unigrams", type=str, required=False, default=None,
                        help="Path to unigrams file")
    parser.add_argument("--num_workers", type=int, required=False, default=8,
                        help="KenLM's number of workers")
    parser.add_argument("-w", "--beam_width", type=int, required=False, default=128,
                        help="KenLM's beam width")
    parser.add_argument("--test_pct", type=int, required=False, default=100,
                        help="Percentage of the test set")
    parser.add_argument("--cpu", required=False, action='store_true',
                        help="Force to use CPU")
    args = parser.parse_args()

    set_seed(42)  # set the random seed to have reproducible result.
    processor = Wav2Vec2Processor.from_pretrained(args.model_name)
    model = Wav2Vec2ForCTC.from_pretrained(args.model_name)
    kenlm = None
    if args.kenlm:
        # BUG FIX: --num_workers and --beam_width were parsed but never
        # forwarded to the decoder.
        kenlm = KenLM(processor.tokenizer, args.kenlm, unigrams=args.unigrams,
                      num_workers=args.num_workers, beam_width=args.beam_width)

    # Preprocessing the datasets.
    # We need to read the audio files as arrays

    def speech_file_to_array_fn(batch):
        """Decode one example's audio and resample it to the model's 16 kHz."""
        if "audio" in batch:
            speech_array = torch.tensor(batch["audio"]["array"])
            # BUG FIX: use the sampling rate reported by the dataset instead of
            # assuming 48 kHz; fall back to 48 kHz when it is absent.
            orig_rate = batch["audio"].get("sampling_rate", 48_000)
            resampler = torchaudio.transforms.Resample(orig_rate, 16_000)
        else:
            speech_array, sampling_rate = torchaudio.load(batch["path"])
            resampler = torchaudio.transforms.Resample(sampling_rate, 16_000)
        batch["speech"] = resampler(speech_array).squeeze().numpy()
        return batch

    def remove_special_characters(batch):
        """Strip punctuation from the reference sentence and lowercase it."""
        batch["norm_text"] = re.sub(chars_to_ignore_regex, "", batch["sentence"]).lower().strip()
        return batch

    lg_test = load_dataset(args.name, args.config_name, data_dir=args.data_dir,
                           split=f"test[:{args.test_pct}%]", use_auth_token=True)
    if args.cpu:
        device = "cpu"
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    lg_test = lg_test.map(speech_file_to_array_fn)
    lg_test = lg_test.map(remove_special_characters)
    model = model.to(device)
    wer = load_metric("wer")

    batch_size = args.batch_size

    def evaluate(batch):
        """Forward one batch through the model and decode (greedy or KenLM)."""
        inputs = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)

        with torch.no_grad():
            logits = model(inputs.input_values.to(device)).logits

        if kenlm is not None:
            batch["pred_strings"] = kenlm.decode(logits)
        else:
            predicted_ids = torch.argmax(logits, dim=-1)
            batch["pred_strings"] = processor.batch_decode(predicted_ids)
        return batch

    result = lg_test.map(evaluate, batched=True, batch_size=batch_size)
    WER = 100 * wer.compute(predictions=result["pred_strings"], references=result["norm_text"])
    print(f"WER: {WER:.2f}%")


if __name__ == "__main__":
    main()