#!/usr/bin/env python3
"""Compare long-form Whisper transcription WER: Hugging Face vs. openai-whisper.

For each selected dataset sample whose features span at least 3000 frames, the
audio is transcribed twice:

* with ``WhisperForConditionalGeneration.generate`` (Hugging Face), and
* with ``whisper.transcribe`` from the original openai-whisper package.

Both hypotheses and the reference are normalized with the Whisper tokenizer's
normalizer. The script then prints the texts, the per-sample WER for each
implementation, and a corpus-level WER over all evaluated samples.
"""
from datasets import load_dataset
from datasets import Audio
import numpy as np
from transformers import WhisperForConditionalGeneration, AutoProcessor, pipeline
import torch
from jiwer import wer
import whisper

# Precision of the HF model (and whether openai-whisper decodes in fp16).
# PRECISION = torch.float16
PRECISION = torch.float32
# Whether both implementations condition each window on previously decoded text.
DO_COND = True
# Both models are placed on the same device so the comparison is like-for-like.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SAMPLING_RATE = 16_000

# model_id = "openai/whisper-tiny"
model_id = "openai/whisper-tiny.en"
# model_id = "openai/whisper-large-v2"

processor = AutoProcessor.from_pretrained(model_id)
model = WhisperForConditionalGeneration.from_pretrained(model_id, torch_dtype=PRECISION)
model = model.to(DEVICE)

# openai-whisper expects the bare checkpoint name, e.g. "tiny.en" for "openai/whisper-tiny.en".
model_orig = whisper.load_model(model_id.split("whisper-")[-1], device=DEVICE)

# ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean")
# ds = load_dataset("distil-whisper/meanwhile", "default")["test"]
ds = load_dataset("distil-whisper/earnings21", "full")["test"]
ds = ds.cast_column("audio", Audio(sampling_rate=SAMPLING_RATE))

num_samples = 3
start = 2
# ``num_samples`` is a count, so the selection ends at start + num_samples
# (previously it was misused as the end index, evaluating fewer samples).
end = start + num_samples

# Slice once so the audio column is decoded a single time for audios and labels.
batch = ds[start:end]
audios = [x["array"] for x in batch["audio"]]

# Datasets name the reference-transcript column differently.
for name in ["text", "transcription"]:
    if name in ds.column_names:
        labels = batch[name]
        break
else:
    raise ValueError(f"No reference text column found in dataset columns {ds.column_names}")

# Upper bound on new tokens passed to HF generate.
gen_length = 448

# Normalized references and hypotheses, collected for corpus-level WER.
refs, hyps_orig, hyps_hf = [], [], []

for audio, label in zip(audios, labels):
    inputs = processor(
        audio,
        return_tensors="pt",
        truncation=False,
        padding="longest",
        return_attention_mask=True,
        sampling_rate=SAMPLING_RATE,
    )
    # Floating-point features are cast to PRECISION; every tensor is moved to DEVICE.
    inputs = inputs.to(DEVICE, PRECISION)

    # Only long-form inputs are compared; fewer than 3000 feature frames
    # (presumably < 30 s of audio) are skipped.
    if inputs["input_features"].shape[-1] < 3000:
        continue

    # Reference implementation: greedy decoding with all fallback heuristics disabled,
    # decoding in the same precision as the HF model.
    result = model_orig.transcribe(
        audio.astype(dtype=np.float32),
        condition_on_previous_text=DO_COND,
        temperature=0.0,
        logprob_threshold=None,
        compression_ratio_threshold=None,
        no_speech_threshold=None,
        fp16=PRECISION == torch.float16,
    )

    result_hf = model.generate(
        **inputs,
        condition_on_prev_tokens=DO_COND,
        max_new_tokens=gen_length,
        return_timestamps=True,
    )
    decoded = processor.batch_decode(result_hf, skip_special_tokens=True)

    result_text_norm = processor.tokenizer._normalize(result["text"])
    decoded_norm = processor.tokenizer._normalize(decoded[0])
    label_norm = processor.tokenizer._normalize(label)

    refs.append(label_norm)
    hyps_orig.append(result_text_norm)
    hyps_hf.append(decoded_norm)

    print("Cond:\n", decoded_norm)
    print(50 * "-")
    print("Orig Cond:\n", result_text_norm)
    print(50 * "-")
    print("Label:\n", label_norm)
    print("WER Orig", wer(label_norm, result_text_norm))
    print("WER HF", wer(label_norm, decoded_norm))

if refs:
    # Corpus-level WER over every evaluated sample (jiwer accepts lists of strings).
    print("Total WER Orig", wer(refs, hyps_orig))
    print("Total WER HF", wer(refs, hyps_hf))
else:
    print("No long-form samples in the selected range; nothing to score.")

print("Done")