# File: new_tools/hf_whisper_meanwhile.py
# Source: Hugging Face Hub file view; author patrickvonplaten, commit 84a4607 ("improve").
#!/usr/bin/env python3
from datasets import load_dataset
from datasets import Audio
import numpy as np
from transformers import WhisperForConditionalGeneration, AutoProcessor, pipeline
import torch
from jiwer import wer
import whisper
# dtype for the HF model and its inputs. Change to torch.float16 for a
# half-precision run. (Previously this was assigned twice; the first value was dead.)
PRECISION = torch.float32
# Passed to both implementations in the evaluation loop
# (`condition_on_previous_text` / `condition_on_prev_tokens`).
DO_COND = True
# model_id = "openai/whisper-tiny"
model_id = "openai/whisper-tiny.en"
# model_id = "openai/whisper-large-v2"
processor = AutoProcessor.from_pretrained(model_id)
model = WhisperForConditionalGeneration.from_pretrained(model_id, torch_dtype=PRECISION)
model = model.to("cuda")
# Load the matching OpenAI checkpoint by its short name,
# e.g. "openai/whisper-tiny.en" -> "tiny.en".
model_orig = whisper.load_model(model_id.split("whisper-")[-1])
# ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean")
# ds = load_dataset("distil-whisper/meanwhile", "default")["test"]
ds = load_dataset("distil-whisper/earnings21", "full")["test"]
# Resample to 16 kHz, the rate passed to the processor in the evaluation loop.
ds = ds.cast_column("audio", Audio(sampling_rate=16000))
num_samples = 3
start = 2
# Take `num_samples` consecutive examples starting at index `start`.
# Slicing `start:num_samples` would yield only `num_samples - start` examples.
# Slice once so the audio column is decoded a single time.
subset = ds[start:start + num_samples]
audios = [x["array"] for x in subset["audio"]]
# Datasets name the reference-transcript column differently; use the first one present.
label_column = next(
    (name for name in ("text", "transcription") if name in ds.column_names),
    None,
)
if label_column is None:
    raise ValueError(
        f"No reference-text column ('text' or 'transcription') in {ds.column_names}"
    )
labels = subset[label_column]
# Compare OpenAI `whisper` long-form transcription with HF `generate` on each
# selected sample, printing normalized transcripts and WER against the label.

# Max new tokens per HF generate call. 448 presumably matches Whisper's decoder
# length; verify against model.config.max_target_positions.
gen_length = 448
# NOTE(review): `_normalize` is private; newer transformers releases may expose
# a public `normalize`. Confirm against the installed version.
normalize = processor.tokenizer._normalize

refs_norm, hyps_orig_norm, hyps_hf_norm = [], [], []
for audio, label in zip(audios, labels):
    inputs = processor(
        audio,
        return_tensors="pt",
        truncation=False,
        padding="longest",
        return_attention_mask=True,
        sampling_rate=16_000,
    )
    # Skip clips with fewer than 3000 feature frames (presumably Whisper's 30 s
    # window); this comparison targets long-form audio only. The check runs
    # before the GPU transfer so skipped clips cost nothing extra.
    if inputs["input_features"].shape[-1] < 3000:
        continue
    inputs = inputs.to("cuda", PRECISION)

    # OpenAI reference: greedy decoding (temperature 0) with the fallback
    # thresholds disabled. Run it at the same precision as the HF model so the
    # comparison is like-for-like (transcribe otherwise defaults to fp16).
    result = model_orig.transcribe(
        audio.astype(dtype=np.float32),
        condition_on_previous_text=DO_COND,
        temperature=0.0,
        logprob_threshold=None,
        compression_ratio_threshold=None,
        no_speech_threshold=None,
        fp16=PRECISION == torch.float16,
    )
    result_hf = model.generate(
        **inputs,
        condition_on_prev_tokens=DO_COND,
        max_new_tokens=gen_length,
        return_timestamps=True,
    )
    decoded = processor.batch_decode(result_hf, skip_special_tokens=True)

    result_text_norm = normalize(result["text"])
    decoded_norm = normalize(decoded[0])
    label_norm = normalize(label)
    refs_norm.append(label_norm)
    hyps_orig_norm.append(result_text_norm)
    hyps_hf_norm.append(decoded_norm)

    print("Cond:\n", decoded_norm)
    print(50 * "-")
    print("Orig Cond:\n", result_text_norm)
    print(50 * "-")
    print("Label:\n", label_norm)
    print("Sample WER Orig", wer(label_norm, result_text_norm))
    print("Sample WER HF", wer(label_norm, decoded_norm))
    print(50 * "=")

print("Result:")
if not refs_norm:
    print("No sample was long enough to evaluate.")
else:
    # Corpus-level WER over every evaluated sample (jiwer accepts lists of strings).
    print("WER Orig", wer(refs_norm, hyps_orig_norm))
    print("WER HF", wer(refs_norm, hyps_hf_norm))
print("Done")