# wav2vec2-base-turkish / run_evaluation.py
import torch
import torchaudio
from datasets import load_dataset, load_metric
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, set_seed
import argparse
from pyctcdecode import build_ctcdecoder
from multiprocessing import Pool
import re
"""
This is the script to run the prediction on test set of Turkish speech dataset.
Usage:
python run_evaluation.y -m <wav2vec2 model_name> -d <Zindi dataset directory> -o <output file name> \
-b <optional batch size, default=16>
"""
class KenLM:
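    """Beam-search CTC decoder backed by a KenLM language model, built with
    pyctcdecode from the wav2vec2 tokenizer vocabulary."""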
def __init__(self, tokenizer, model_name, unigrams=None, num_workers=8, beam_width=128):
self.num_workers = num_workers
self.beam_width = beam_width
        vocab_dict = tokenizer.get_vocab()
        # Sort tokens by id so the labels line up with the CTC logit columns,
        # then drop the last vocabulary entry before building the decoder.
        self.vocabulary = [x[0] for x in sorted(vocab_dict.items(), key=lambda x: x[1], reverse=False)]
        self.vocabulary = self.vocabulary[:-1]
        # model_name is the path to the KenLM language model; unigrams is accepted but not used.
        self.decoder = build_ctcdecoder(self.vocabulary, model_name)
@staticmethod
    def lm_postprocess(text):
        # Drop single-character tokens left over from decoding.
        return ' '.join(x for x in text.split() if len(x) > 1)
    def decode(self, logits):
        probs = logits.cpu().numpy()
        # decode_batch expects an iterable of (time, vocab) arrays, one per utterance.
        with Pool(self.num_workers) as pool:
            text = self.decoder.decode_batch(pool, probs, beam_width=self.beam_width)
        text = [KenLM.lm_postprocess(x) for x in text]
        return text
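
# Punctuation and quotation characters stripped from reference transcripts before scoring.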
chars_to_ignore_regex = "[,?.!;:\"“‘”'`…’»«-]"
def main():
parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model_name", type=str, required=True,
help="The wav2vec2 model name")
parser.add_argument("-n", "--name", type=str, required=True,
help="The name of dataset")
parser.add_argument("-c", "--config_name", type=str, required=True,
help="The config name of the dataset")
parser.add_argument("-d", "--data_dir", type=str, required=False, default=None,
help="The directory contains the dataset")
parser.add_argument("-b", "--batch_size", type=int, required=False, default=16,
help="Batch size")
parser.add_argument("-k", "--kenlm", type=str, required=False, default=False,
help="Path to KenLM model")
parser.add_argument("-u", "--unigrams", type=str, required=False, default=None,
help="Path to unigrams file")
parser.add_argument("--num_workers", type=int, required=False, default=8,
help="KenLM's number of workers")
parser.add_argument("-w", "--beam_width", type=int, required=False, default=128,
help="KenLM's beam width")
parser.add_argument("--test_pct", type=int, required=False, default=100,
help="Percentage of the test set")
parser.add_argument("--cpu", required=False, action='store_true',
help="Force to use CPU")
args = parser.parse_args()
    set_seed(42)  # Set the random seed for reproducible results.
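    # Load the fine-tuned wav2vec2 model and its processor (feature extractor + tokenizer).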
processor = Wav2Vec2Processor.from_pretrained(args.model_name)
model = Wav2Vec2ForCTC.from_pretrained(args.model_name)
kenlm = None
if args.kenlm:
        kenlm = KenLM(processor.tokenizer, args.kenlm, args.unigrams,
                      num_workers=args.num_workers, beam_width=args.beam_width)
# Preprocessing the datasets.
# We need to read the audio files as arrays
    def speech_file_to_array_fn(batch):
        if "audio" in batch:
            # The dataset already provides a decoded waveform; resample it from
            # its native sampling rate to the 16 kHz expected by wav2vec2.
            speech_array = torch.tensor(batch["audio"]["array"], dtype=torch.float32)
            resampler = torchaudio.transforms.Resample(batch["audio"]["sampling_rate"], 16_000)
        else:
            # Otherwise load the audio file from disk and resample it.
            speech_array, sampling_rate = torchaudio.load(batch["path"])
            resampler = torchaudio.transforms.Resample(sampling_rate, 16_000)
        batch["speech"] = resampler(speech_array).squeeze().numpy()
        return batch
def remove_special_characters(batch):
batch["norm_text"] = re.sub(chars_to_ignore_regex, "", batch["sentence"]).lower().strip()
return batch
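    # Load the requested slice of the test split.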
lg_test = load_dataset(args.name, args.config_name, data_dir=args.data_dir,
split=f"test[:{args.test_pct}%]", use_auth_token=True)
if args.cpu:
device = "cpu"
else:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
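    # Decode/resample the audio and normalize the reference transcripts.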
lg_test = lg_test.map(speech_file_to_array_fn)
lg_test = lg_test.map(remove_special_characters)
model = model.to(device)
wer = load_metric("wer")
batch_size = args.batch_size
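    # Run the model on a batch and decode either greedily (argmax over the CTC logits)
    # or with the KenLM beam-search decoder.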
def evaluate(batch):
inputs = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
with torch.no_grad():
logits = model(inputs.input_values.to(device)).logits
if args.kenlm:
batch["pred_strings"] = kenlm.decode(logits)
else:
predicted_ids = torch.argmax(logits, dim=-1)
batch["pred_strings"] = processor.batch_decode(predicted_ids)
return batch
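    # Map over the test set in batches and score the predictions against the
    # normalized references with word error rate.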
result = lg_test.map(evaluate, batched=True, batch_size=batch_size)
WER = 100 * wer.compute(predictions=result["pred_strings"], references=result["norm_text"])
print(f"WER: {WER:.2f}%")
if __name__ == "__main__":
main()