Mismatched WER and small bugs

#9
by yourex - opened

Hi, I ran into some issues while running the sample code. When I ran it directly, I got the following error:

TypeError                                 Traceback (most recent call last)
     19     batch["transcription"] = transcription
     20     return batch
---> 22 result = librispeech_eval.map(map_to_pred, batched=True, batch_size=1, remove_columns=["audio"])
     24 #flattened_list = [item[0] for item in result["transcription"]]
     26 print("WER:", wer(result["text"], result["transcription"]))

File \datasets\arrow_dataset.py:592, in transmit_tasks.<locals>.wrapper(*args, **kwargs)
    590     self: "Dataset" = kwargs.pop("self")
    591 # apply actual function
--> 592 out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
    593 datasets: List["Dataset"] = list(out.values()) if isinstance(out, dict) else [out]
    594 for dataset in datasets:
    595     # Remove task templates if a column mapping of the template is no longer valid

File \datasets\arrow_dataset.py:557, in transmit_format.<locals>.wrapper(*args, **kwargs)
    550 self_format = {
    551     "type": self._format_type,
    552     "format_kwargs": self._format_kwargs,
    553     "columns": self._format_columns,
    554     "output_all_columns": self._output_all_columns,
    555 }
    556 # apply actual function
...
---> 13     input_values = processor(batch["audio"]["array"], return_tensors="pt", padding="longest").input_values
     14     with torch.no_grad():
     15         logits = model(input_values.to("cuda")).logits

TypeError: list indices must be integers or slices, not str

So I modified the sample code as:

result = librispeech_eval.map(map_to_pred, batched=False, batch_size=1, remove_columns=["audio"])

flattened_list = [item[0] for item in result["transcription"]]

print("WER:", wer(result["text"], flattened_list))
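(For context on the flattening step: with batched=False, processor.batch_decode still returns a one-element list for each example, so result["transcription"] ends up as a list of single-element lists rather than a flat list of strings. A minimal pure-Python sketch, using made-up transcriptions:)

```python
# Shape of result["transcription"] when batched=False: each example's
# batch_decode output is a one-element list.
transcriptions = [["HELLO WORLD"], ["GOOD MORNING"]]

# Flatten before scoring, matching the modification above:
flattened_list = [item[0] for item in transcriptions]
print(flattened_list)  # ['HELLO WORLD', 'GOOD MORNING']
```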

Then I got a mismatched WER, along with a warning:

It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
WER: 0.0338557516737675

Is this caused by not passing the sampling rate?

Would you know what might be happening @sanchit-gandhi ?

The sample code I ran is this one:

Evaluation
This code snippet shows how to evaluate facebook/wav2vec2-base-960h on LibriSpeech's "clean" and "other" test data.

from datasets import load_dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import torch
from jiwer import wer


librispeech_eval = load_dataset("librispeech_asr", "clean", split="test")

model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").to("cuda")
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")

def map_to_pred(batch):
    input_values = processor(batch["audio"]["array"], return_tensors="pt", padding="longest").input_values
    with torch.no_grad():
        logits = model(input_values.to("cuda")).logits

    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)
    batch["transcription"] = transcription
    return batch

result = librispeech_eval.map(map_to_pred, batched=True, batch_size=1, remove_columns=["audio"])

print("WER:", wer(result["text"], result["transcription"]))

Hey @yourex ! You're correct that the code snippet is broken. I've opened a PR to correct the snippet and activate batching here: #10 (and for the large model here: PR 4, and the self-trained model here: PR 6). Perhaps @patrickvonplaten could merge these PRs? (Wav2Vec2 is maintained entirely by HF on the Hub.)
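For anyone hitting the same TypeError: with batched=True, datasets passes each column to the map function as a list of values, one per example. So batch["audio"] is a list of dicts, and indexing it with the string "array" fails. A minimal pure-Python sketch of the failure mode (no datasets needed; the dict below just mimics one decoded audio sample):

```python
# Under batched=True, a column arrives as a list of per-example values:
batch_audio = [{"array": [0.0, 0.1], "sampling_rate": 16000}]

try:
    batch_audio["array"]  # indexing a list with a str
except TypeError as e:
    print(e)  # list indices must be integers or slices, not str

# The batched fix indexes the list first and collects the raw arrays:
arrays = [example["array"] for example in batch_audio]
print(arrays)  # [[0.0, 0.1]]
```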

The WER you calculated with your modified code snippet is entirely correct. You obtained a WER of 0.03385, or 3.4%, which matches the expected results (see the bottom of the evaluation section). The sampling rate of LibriSpeech is 16 kHz, which matches the sampling rate of the Wav2Vec2 feature extractor, so in this case there are no pre-processing errors. However, it is good practice to pass the sampling rate to prevent silent errors, as is done in the updated code snippet.
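As a side note on the metric itself: jiwer.wer is, informally, the word-level Levenshtein distance between hypothesis and reference, divided by the number of reference words. A self-contained illustration in plain Python (hypothetical sentences, standing in for jiwer):

```python
def word_error_rate(reference: str, hypothesis: str) -> float:
    """Word-level Levenshtein distance divided by the reference length."""
    ref, hyp = reference.split(), hypothesis.split()
    # DP table: d[i][j] = edits needed to turn ref[:i] into hyp[:j]
    d = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        d[i][0] = i
    for j in range(len(hyp) + 1):
        d[0][j] = j
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            d[i][j] = min(
                d[i - 1][j] + 1,         # deletion
                d[i][j - 1] + 1,         # insertion
                d[i - 1][j - 1] + cost,  # substitution
            )
    return d[len(ref)][len(hyp)] / len(ref)

# One substituted word out of six reference words -> WER of 1/6:
print(round(word_error_rate("the cat sat on the mat", "the cat sat on a mat"), 4))  # 0.1667
```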

yourex changed discussion status to closed
