"""
Run prediction on the test split of a Turkish speech dataset and report the WER.

Usage:
    python run_evaluation.py -m <model_name> -n <dataset_name> -c <config_name>
        [-d <data_dir>] [-b <batch_size>] [--test_pct <pct>] [--cpu]
        [-k <kenlm_model> [-u <unigrams_file>] [-w <beam_width>] [--num_workers <n>]]
"""
import argparse
import os
import re
from multiprocessing import Pool

import torch
import torchaudio
from datasets import load_dataset, load_metric
from pyctcdecode import build_ctcdecoder
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, set_seed

# Sampling rate every waveform is resampled to; the processor is called with it below.
TARGET_SAMPLING_RATE = 16_000

# Source rate assumed for "audio" entries that do not report their own
# sampling rate (this branch used to hard-code 48 kHz unconditionally).
DEFAULT_AUDIO_SAMPLING_RATE = 48_000

# Punctuation stripped from reference transcripts before scoring. Written as a
# raw string so backslashes reach the regex engine instead of being (invalid)
# Python string escapes; the character set is unchanged.
chars_to_ignore_regex = r'[,?.!\-;:"“‘”\'`…’»«]'
_CHARS_TO_IGNORE = re.compile(chars_to_ignore_regex)


class KenLM:
    """Beam-search CTC decoder backed by a KenLM language model via pyctcdecode.

    Args:
        tokenizer: Tokenizer exposing ``get_vocab()`` (token -> id). The decoder
            labels are the tokens ordered by id, minus the last one.
        model_name: Path to the KenLM model, passed to ``build_ctcdecoder``.
        unigrams: Optional unigrams for the decoder. Either an iterable of words,
            or a path (str / os.PathLike) to a UTF-8 text file of
            whitespace-separated words. ``None`` disables unigrams.
        num_workers: Number of worker processes used by ``decode``.
        beam_width: Beam width passed to pyctcdecode in ``decode``.
    """

    def __init__(self, tokenizer, model_name, unigrams=None, num_workers=8, beam_width=128):
        self.num_workers = num_workers
        self.beam_width = beam_width
        vocab_dict = tokenizer.get_vocab()
        # Labels must be ordered by token id so they line up with the logit columns.
        self.vocabulary = sorted(vocab_dict, key=vocab_dict.get)
        # NOTE(review): drops the highest-id token, presumably an added special
        # token with no matching logit column; confirm against model.config.vocab_size.
        self.vocabulary = self.vocabulary[:-1]
        if isinstance(unigrams, (str, os.PathLike)):
            unigrams = self._read_unigrams(unigrams)
        self.decoder = build_ctcdecoder(self.vocabulary, model_name, unigrams=unigrams)

    @staticmethod
    def _read_unigrams(path):
        """Return the whitespace-separated words of the UTF-8 text file at *path*."""
        with open(path, encoding="utf-8") as f:
            return f.read().split()

    @staticmethod
    def lm_postprocess(text):
        """Drop single-character words from *text* and normalise spacing.

        Words are split on whitespace and re-joined with single spaces, so no
        leading/trailing or doubled spaces remain where words were removed.
        """
        return ' '.join(word for word in text.split() if len(word) > 1)

    def decode(self, logits):
        """Beam-search decode a batch of logits into post-processed strings.

        Args:
            logits: Torch tensor of model logits, one row per utterance
                (assumed (batch, time, vocab) -- TODO confirm for all callers).

        Returns:
            List of decoded strings, one per batch row.
        """
        probs = logits.cpu().numpy()
        # pyctcdecode requires the pool to be created after the decoder exists.
        with Pool(self.num_workers) as pool:
            texts = self.decoder.decode_batch(pool, probs, beam_width=self.beam_width)
        return [KenLM.lm_postprocess(text) for text in texts]


def speech_file_to_array_fn(batch):
    """Load one example's audio, resample it and store it under ``batch["speech"]``.

    Uses ``batch["audio"]`` (array plus sampling rate) when present; otherwise
    loads ``batch["path"]`` from disk with torchaudio.
    """
    if "audio" in batch:
        audio = batch["audio"]
        # float32 keeps the waveform dtype in line with torchaudio's cached
        # resampling kernel.
        speech_array = torch.tensor(audio["array"], dtype=torch.float32)
        sampling_rate = audio.get("sampling_rate", DEFAULT_AUDIO_SAMPLING_RATE)
    else:
        speech_array, sampling_rate = torchaudio.load(batch["path"])
    resampler = torchaudio.transforms.Resample(sampling_rate, TARGET_SAMPLING_RATE)
    # NOTE(review): squeeze() only flattens mono audio; multi-channel files stay 2-D.
    batch["speech"] = resampler(speech_array).squeeze().numpy()
    return batch


def remove_special_characters(batch):
    """Store the normalised reference transcript under ``batch["norm_text"]``.

    Strips the punctuation in ``chars_to_ignore_regex``, lowercases and trims
    ``batch["sentence"]``.
    """
    # NOTE(review): str.lower() is not Turkish-aware ("I" -> "i", "İ" -> "i" plus
    # a combining dot). Keep it in sync with whatever normalisation training used.
    batch["norm_text"] = _CHARS_TO_IGNORE.sub("", batch["sentence"]).lower().strip()
    return batch


def _parse_args():
    """Parse the command-line options of the evaluation script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model_name", type=str, required=True,
                        help="The wav2vec2 model name")
    parser.add_argument("-n", "--name", type=str, required=True,
                        help="The name of dataset")
    parser.add_argument("-c", "--config_name", type=str, required=True,
                        help="The config name of the dataset")
    parser.add_argument("-d", "--data_dir", type=str, required=False, default=None,
                        help="The directory contains the dataset")
    parser.add_argument("-b", "--batch_size", type=int, required=False, default=16,
                        help="Batch size")
    parser.add_argument("-k", "--kenlm", type=str, required=False, default=None,
                        help="Path to KenLM model")
    parser.add_argument("-u", "--unigrams", type=str, required=False, default=None,
                        help="Path to unigrams file")
    parser.add_argument("--num_workers", type=int, required=False, default=8,
                        help="KenLM's number of workers")
    parser.add_argument("-w", "--beam_width", type=int, required=False, default=128,
                        help="KenLM's beam width")
    parser.add_argument("--test_pct", type=int, required=False, default=100,
                        help="Percentage of the test set")
    parser.add_argument("--cpu", required=False, action='store_true',
                        help="Force to use CPU")
    return parser.parse_args()


def main():
    """Evaluate a wav2vec2 CTC model on a dataset's test split and print the WER."""
    args = _parse_args()
    set_seed(42)  # set the random seed to have reproducible result.

    processor = Wav2Vec2Processor.from_pretrained(args.model_name)
    model = Wav2Vec2ForCTC.from_pretrained(args.model_name)

    kenlm = None
    if args.kenlm:
        # Forward the CLI decoding options; previously they were parsed but ignored.
        kenlm = KenLM(
            processor.tokenizer,
            args.kenlm,
            args.unigrams,
            num_workers=args.num_workers,
            beam_width=args.beam_width,
        )

    lg_test = load_dataset(
        args.name,
        args.config_name,
        data_dir=args.data_dir,
        split=f"test[:{args.test_pct}%]",
        use_auth_token=True,
    )

    if args.cpu or not torch.cuda.is_available():
        device = torch.device("cpu")
    else:
        device = torch.device("cuda")

    lg_test = lg_test.map(speech_file_to_array_fn)
    lg_test = lg_test.map(remove_special_characters)

    model = model.to(device)
    wer = load_metric("wer")

    def evaluate(batch):
        """Transcribe a batch of examples into ``batch["pred_strings"]``."""
        inputs = processor(batch["speech"], sampling_rate=TARGET_SAMPLING_RATE,
                           return_tensors="pt", padding=True)
        # The processor only returns an attention mask when its config asks for
        # one; forward it in that case so padded frames are masked.
        attention_mask = inputs.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)
        with torch.no_grad():
            logits = model(inputs.input_values.to(device), attention_mask=attention_mask).logits

        if kenlm is not None:
            batch["pred_strings"] = kenlm.decode(logits)
        else:
            predicted_ids = torch.argmax(logits, dim=-1)
            batch["pred_strings"] = processor.batch_decode(predicted_ids)
        return batch

    result = lg_test.map(evaluate, batched=True, batch_size=args.batch_size)

    wer_score = 100 * wer.compute(predictions=result["pred_strings"], references=result["norm_text"])
    print(f"WER: {wer_score:.2f}%")


if __name__ == "__main__":
    main()