Only the logits for the decoder_input_ids are returned, not for the actual input_features

#8 opened by joeyontour

In the sample code from the model card, only the logits for the language token are returned, not the logits for the actual audio. I cannot use the generate function as-is because I need the raw logits to compute word-level timestamps and to combine them with a language model. Is there a way to obtain the logits?

>>> # Forward pass with a single decoder token
>>> # (50258 is <|startoftranscript|> for the multilingual checkpoints)
>>> logits = model(input_features, decoder_input_ids=torch.tensor([[50258]])).logits
>>> # Take the argmax and decode: only one predicted token comes back
>>> predicted_ids = torch.argmax(logits, dim=-1)
>>> transcription = processor.batch_decode(predicted_ids)
['<|en|>']
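
The shape of the returned logits makes the problem visible: the forward pass yields exactly one vocabulary distribution per decoder input token, no more (output shown assuming whisper-large's 51865-token multilingual vocabulary):

>>> logits.shape  # (batch, decoder_seq_len, vocab_size)
torch.Size([1, 1, 51865])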

You have to run model.generate(...) to get more than the first token; a single forward pass only scores the decoder tokens you pass in.
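
If you need the logits for every generated token, generate can also return the per-step scores directly. A minimal sketch, assuming a reasonably recent transformers version (note the returned scores are the logits after any logits processors have been applied):

>>> output = model.generate(
...     input_features,
...     return_dict_in_generate=True,
...     output_scores=True,
... )
>>> output.sequences  # generated token ids, shape (batch, seq_len)
>>> output.scores     # tuple with one (batch, vocab_size) tensor per generated token
>>> transcription = processor.batch_decode(output.sequences, skip_special_tokens=True)

These per-step scores are what you would feed into word-level timestamp heuristics or shallow fusion with a language model.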

In that case, I don't understand how the evaluation in the model card was performed: there, the logits are extracted with a single forward pass and the ids are decoded. When I run this locally, I indeed get only an empty transcription after normalization, so I'm wondering how the reported WER was obtained.

>>> from datasets import load_dataset
>>> from transformers import WhisperForConditionalGeneration, WhisperProcessor
>>> from jiwer import wer
>>> import torch
>>> librispeech_eval = load_dataset("librispeech_asr", "clean", split="test")
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large").to("cuda")
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-large")
>>> def map_to_pred(batch):
...     audio = batch["audio"][0]  # batch_size=1, so take the single example
...     input_features = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt").input_features
...     with torch.no_grad():
...         # note: a single forward pass, not autoregressive generation
...         # (this is the bug under discussion)
...         logits = model(input_features.to("cuda")).logits
...     predicted_ids = torch.argmax(logits, dim=-1)
...     transcription = processor.batch_decode(predicted_ids, normalize=True)
...     batch["text"] = [processor.tokenizer._normalize(t) for t in batch["text"]]
...     batch["transcription"] = transcription
...     return batch
>>> result = librispeech_eval.map(map_to_pred, batched=True, batch_size=1, remove_columns=["audio"])
>>> print("WER:", wer(result["text"], result["transcription"]))
0.030003583080317572

Hey! You are correct, the snippet is wrong; we indeed used generate. Will fix the evaluation code. Thanks for the catch!
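
For anyone landing here before the model card is fixed, the corrected evaluation presumably looks something like the sketch below, with the forward pass swapped for generate; this is a reconstruction based on the reply above, not the official fixed snippet:

>>> def map_to_pred(batch):
...     audio = batch["audio"][0]
...     input_features = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt").input_features
...     with torch.no_grad():
...         # autoregressive decoding produces the full transcription
...         predicted_ids = model.generate(input_features.to("cuda"))
...     transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True, normalize=True)
...     batch["text"] = [processor.tokenizer._normalize(t) for t in batch["text"]]
...     batch["transcription"] = transcription
...     return batch
>>> result = librispeech_eval.map(map_to_pred, batched=True, batch_size=1, remove_columns=["audio"])
>>> print("WER:", wer(result["text"], result["transcription"]))

With generate, the transcription column contains full hypotheses rather than a single language token, so the WER computation becomes meaningful.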
