patrickvonplaten committed
Commit
dc89bde
1 Parent(s): cd9f529
Files changed (1): hf_whisper_meanwhile.py +81 -0
hf_whisper_meanwhile.py ADDED
@@ -0,0 +1,81 @@
#!/usr/bin/env python3
from datasets import load_dataset, Audio
import numpy as np
from transformers import WhisperForConditionalGeneration, AutoProcessor
import torch
from jiwer import wer
import whisper

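# Compare long-form transcription from Transformers' Whisper against the
# original openai-whisper implementation and score both against the reference
# transcripts with WER. DO_COND toggles conditioning on previous text in both.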
# PRECISION = torch.float16
PRECISION = torch.float32
DO_COND = True

# model_id = "openai/whisper-tiny"
model_id = "openai/whisper-tiny.en"
# model_id = "openai/whisper-large-v2"

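# Load the Hugging Face checkpoint and its processor, plus the reference
# openai-whisper model (which expects short names like "tiny.en").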
processor = AutoProcessor.from_pretrained(model_id)
model = WhisperForConditionalGeneration.from_pretrained(model_id, torch_dtype=PRECISION)
model = model.to("cuda")

model_orig = whisper.load_model(model_id.split("whisper-")[-1])

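# Long-form evaluation data; resample to the 16 kHz rate Whisper expects.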
# ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean")
# ds = load_dataset("distil-whisper/meanwhile", "default")["test"]
ds = load_dataset("distil-whisper/earnings21", "full")["test"]
ds = ds.cast_column("audio", Audio(sampling_rate=16000))

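# Select rows [start, num_samples) — with start=2 and num_samples=3 that is a
# single example; the label column is "text" or "transcription" depending on
# the dataset.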
num_samples = 3
start = 2

audios = [x['array'] for x in ds[start:num_samples]["audio"]]
for name in ["text", "transcription"]:
    if name in ds.column_names:
        labels = ds[start:num_samples][name]
        break

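# Transcribe each example with both implementations and compare word error rates.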
for audio, label in zip(audios, labels):
    inputs = processor(audio, return_tensors="pt", truncation=False, padding="longest", return_attention_mask=True, sampling_rate=16_000)
    inputs = inputs.to("cuda", PRECISION)

    # fewer than 3000 mel frames means less than 30s of audio: short-form,
    # so it would not exercise the sequential long-form decoding path
    if inputs["input_features"].shape[-1] < 3000:
        continue

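    # Reference: openai-whisper with greedy decoding (temperature 0.0) and all
    # fallback heuristics disabled, so conditioning is the only variable.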
    result = model_orig.transcribe(audio.astype(dtype=np.float32), condition_on_previous_text=DO_COND, temperature=0.0, logprob_threshold=None, compression_ratio_threshold=None, no_speech_threshold=None)

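    # Transformers: long-form generation requires return_timestamps=True; 448
    # is Whisper's decoder context length (max_target_positions).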
    gen_length = 448
    result_hf = model.generate(**inputs, condition_on_prev_tokens=DO_COND, max_new_tokens=gen_length, return_timestamps=True)
    decoded = processor.batch_decode(result_hf, skip_special_tokens=True)

    # result = model.generate(**inputs, condition_on_prev_tokens=False, max_new_tokens=gen_length, return_timestamps=True)
    # decoded_2 = processor.batch_decode(result)
    # print(50 * "-")

    # result_2 = model_orig.transcribe(audio.astype(dtype=np.float32), condition_on_previous_text=False, temperature=0.0, logprob_threshold=None, compression_ratio_threshold=None, no_speech_threshold=None)

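    # Normalize all texts with Whisper's tokenizer normalizer before scoring,
    # so casing and punctuation differences do not inflate the WER.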
    result_text_norm = processor.tokenizer._normalize(result["text"])
    decoded_norm = processor.tokenizer._normalize(decoded[0])
    label_norm = processor.tokenizer._normalize(label)

    wer_orig = wer(label_norm, result_text_norm)
    wer_hf = wer(label_norm, decoded_norm)

    print("Cond:\n", decoded_norm)
    print(50 * "-")
    # print("Not cond:\n", decoded_2)
    # print(50 * "-")
    print("Orig Cond:\n", result_text_norm)
    print(50 * "-")
    # print("Orig Not cond:\n", [result_2["text"]])
    # print(50 * "=")
    print("Label:\n", label_norm)

    # break

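# The summary below reports the WER of the last example that was processed.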
print("Result:")
print("WER Orig", wer_orig)
print("WER HF", wer_hf)

print("Done")
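# Usage sketch (assumed environment): requires a CUDA GPU plus the
# transformers, datasets, jiwer and openai-whisper packages installed.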