#!/usr/bin/env python3
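"""Compare long-form transcription between the reference openai-whisper
package and the Hugging Face transformers Whisper implementation.

For a few long recordings, both implementations transcribe the audio
(optionally conditioning each 30 s chunk on the previously generated text),
the outputs are normalized, and the word error rate (WER) of each is
reported against the reference transcript.

Requires: torch with CUDA, transformers, datasets, jiwer, openai-whisper.
"""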
import numpy as np
import torch
import whisper
from datasets import Audio, load_dataset
from jiwer import wer
from transformers import AutoProcessor, WhisperForConditionalGeneration

# Run in float32 by default; switch to float16 to test reduced precision.
PRECISION = torch.float32
# PRECISION = torch.float16

# Condition each 30 s chunk on the previously generated text
# (`condition_on_previous_text` in openai-whisper,
# `condition_on_prev_tokens` in transformers).
DO_COND = True

# model_id = "openai/whisper-tiny"
model_id = "openai/whisper-tiny.en"
# model_id = "openai/whisper-large-v2"

processor = AutoProcessor.from_pretrained(model_id)
model = WhisperForConditionalGeneration.from_pretrained(model_id, torch_dtype=PRECISION)
model = model.to("cuda")

# Load the matching checkpoint in the reference openai-whisper package
# (e.g. "openai/whisper-tiny.en" -> "tiny.en").
model_orig = whisper.load_model(model_id.split("whisper-")[-1])

# ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean")
# ds = load_dataset("distil-whisper/meanwhile", "default")["test"]
ds = load_dataset("distil-whisper/earnings21", "full")["test"]
ds = ds.cast_column("audio", Audio(sampling_rate=16000))

num_samples = 3
start = 2
end = start + num_samples

audios = [x["array"] for x in ds[start:end]["audio"]]
for name in ["text", "transcription"]:
    if name in ds.column_names:
        labels = ds[start:end][name]
        break
else:
    raise KeyError("expected a 'text' or 'transcription' column in the dataset")

for audio, label in zip(audios, labels):
    # Full-length features (truncation=False) are required so that inputs
    # longer than 30 s take the long-form generation path.
    inputs = processor(
        audio,
        return_tensors="pt",
        truncation=False,
        padding="longest",
        return_attention_mask=True,
        sampling_rate=16_000,
    )
    inputs = inputs.to("cuda", PRECISION)

    if inputs["input_features"].shape[-1] < 3000:
        continue
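
    # Optional debug print, assuming the standard Whisper feature rate of
    # 100 mel frames per second (so 3000 frames = one 30 s chunk).
    n_frames = inputs["input_features"].shape[-1]
    print(f"audio length: {n_frames / 100:.1f}s, ~{n_frames / 3000:.0f} chunk(s)")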

    # Reference implementation: greedy decoding with the fallback heuristics
    # disabled so both implementations decode under the same conditions.
    result = model_orig.transcribe(
        audio.astype(dtype=np.float32),
        condition_on_previous_text=DO_COND,
        temperature=0.0,
        logprob_threshold=None,
        compression_ratio_threshold=None,
        no_speech_threshold=None,
    )

    gen_length = 448  # Whisper's maximum target length

    # Hugging Face implementation, with conditioning controlled by DO_COND.
    result_hf = model.generate(**inputs, condition_on_prev_tokens=DO_COND, max_new_tokens=gen_length, return_timestamps=True)
    decoded = processor.batch_decode(result_hf, skip_special_tokens=True)

    # Same, but never conditioning on previous tokens, for comparison.
    result_hf_2 = model.generate(**inputs, condition_on_prev_tokens=False, max_new_tokens=gen_length, return_timestamps=True)
    decoded_2 = processor.batch_decode(result_hf_2, skip_special_tokens=True)

    # Normalize all transcripts with the tokenizer's text normalizer
    # before scoring.
    result_text_norm = processor.tokenizer._normalize(result["text"])
    decoded_norm = processor.tokenizer._normalize(decoded[0])
    label_norm = processor.tokenizer._normalize(label)

    wer_orig = wer(label_norm, result_text_norm)
    wer_hf = wer(label_norm, decoded_norm)

    print("Cond:\n", decoded_norm)
    print(50 * "-")
    # print("Not cond:\n", decoded_2)
    # print(50 * "-")
    print("Orig Cond:\n", result_text_norm)
    print(50 * "-")
    # print("Orig Not cond:\n", [result_2["text"]])
    # print(50 * "=")
    print("Label:\n", label_norm)

    # break
    print("Result:")
    print("WER Orig", wer_orig)
    print("WER HF", wer_hf)

    print("Done")