#!/usr/bin/env python3
"""Evaluate the CTC checkpoint in the current directory on Common Voice ``es``.

Usage:
    <script> DO_LM EVAL_SIZE

DO_LM      integer flag; non-zero decodes with ``Wav2Vec2ProcessorWithLM``
           (n-gram LM beam search), zero uses greedy argmax decoding with
           ``Wav2Vec2Processor``.
EVAL_SIZE  number of test utterances to stream from the dataset and score.

Each prediction is printed next to its normalized reference. WER and CER
(in percent) are printed at the end.
"""
import itertools
import re
import sys

import torch
import torchaudio.functional as F
from datasets import load_dataset, load_metric
from transformers import AutoModelForCTC, Wav2Vec2Processor
from transformers.models.wav2vec2.processing_wav2vec2_with_lm import Wav2Vec2ProcessorWithLM

MODEL_PATH = "./"

CHARS_TO_IGNORE = [",", "?", "¿", ".", "!", "¡", ";", ";", ":", '""', "%", '"', "�", "ʿ", "·", "჻", "~", "՞",
                   "؟", "،", "।", "॥", "«", "»", "„", "“", "”", "「", "」", "‘", "’", "《", "》", "(", ")", "[", "]",
                   "{", "}", "=", "`", "_", "+", "<", ">", "…", "–", "°", "´", "ʾ", "‹", "›", "©", "®", "—", "→", "。",
                   "、", "﹂", "﹁", "‧", "~", "﹏", ",", "{", "}", "(", ")", "[", "]", "【", "】", "‥", "〽",
                   "『", "』", "〝", "〟", "⟨", "⟩", "〜", ":", "!", "?", "♪", "؛", "/", "\\", "º", "−", "^", "ʻ", "ˆ"]
# One character class matching any of the characters above. It is compiled
# once here instead of being re-parsed for every sample.
CHARS_TO_IGNORE_REGEX = re.compile(f"[{re.escape(''.join(CHARS_TO_IGNORE))}]")


def normalize_reference(text):
    """Strip ignored punctuation, lowercase, and collapse whitespace runs to one space."""
    text = CHARS_TO_IGNORE_REGEX.sub("", text).lower()
    return " ".join(text.split())


def main(argv):
    """Run the evaluation described in the module docstring.

    Args:
        argv: command-line vector; ``argv[1]`` is DO_LM, ``argv[2]`` is EVAL_SIZE.
    """
    if len(argv) < 3:
        sys.exit(f"usage: {argv[0]} DO_LM EVAL_SIZE")
    # Decide from the command line whether LM-boosted decoding is used.
    do_lm = bool(int(argv[1]))
    eval_size = int(argv[2])

    device = "cuda" if torch.cuda.is_available() else "cpu"

    wer = load_metric("wer")
    cer = load_metric("cer")

    # Load the model and processor.
    processor = (Wav2Vec2ProcessorWithLM.from_pretrained(MODEL_PATH) if do_lm
                 else Wav2Vec2Processor.from_pretrained(MODEL_PATH))
    model = AutoModelForCTC.from_pretrained(MODEL_PATH).to(device)
    # Resample to whatever rate the feature extractor was configured for,
    # rather than a hard-coded constant.
    target_rate = processor.feature_extractor.sampling_rate

    ds = load_dataset("common_voice", "es", split="test", streaming=True)

    references = []
    predictions = []

    # islice stops cleanly if the split holds fewer than eval_size samples,
    # instead of leaking a StopIteration out of next().
    for sample in itertools.islice(ds, eval_size):
        audio = sample["audio"]
        speech = torch.tensor(audio["array"])
        source_rate = audio["sampling_rate"]
        if source_rate != target_rate:
            speech = F.resample(speech, source_rate, target_rate)

        input_values = processor(speech.numpy(), return_tensors="pt", sampling_rate=target_rate).input_values

        with torch.no_grad():
            logits = model(input_values.to(device)).logits.cpu()

        if do_lm:
            # The LM processor decodes numpy logits and returns an output
            # object; the transcriptions are in its `.text` field.
            output_str = processor.batch_decode(logits.numpy()).text[0].lower()
        else:
            pred_ids = torch.argmax(logits, dim=-1)
            output_str = processor.batch_decode(pred_ids)[0].lower()

        ref_str = normalize_reference(sample["sentence"])

        print(f"Pred: {output_str} | Target: {ref_str}")
        print(50 * "=")

        references.append(ref_str)
        predictions.append(output_str)

    if not predictions:
        sys.exit("No samples were evaluated; nothing to score.")

    print(f"WER: {wer.compute(predictions=predictions, references=references) * 100}")
    print(f"CER: {cer.compute(predictions=predictions, references=references) * 100}")


if __name__ == "__main__":
    main(sys.argv)