File size: 2,383 Bytes
de07127 50f4595 4afec78 b4725a8 decc59e 3bef3fb 664eb76 0b5b7f4 8a965da 3b57b43 3826e01 973bb39 62ac43e d47af16 62ac43e dca4d0e 3bef3fb db75012 3bef3fb db75012 3ce5824 75a25c7 e305646 66476bb 5475a7a 79dce08 75a25c7 3bef3fb 75a25c7 3bef3fb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
"""Compute the word error rate of a fine-tuned Whisper checkpoint.

The script installs its Python dependencies, loads the ``mskov/whisper_miso``
checkpoint and the ``test`` split of ``mskov/miso_test``, transcribes every
audio sample and prints the word error rate (WER) of the transcriptions
against the dataset's reference transcripts.

Environment variables:
    huggingface_token: access token passed to ``from_pretrained`` (required).
    transcript_column: optional name of the dataset column that holds the
        reference transcripts; when unset, a few common names are tried.
"""
import os
import subprocess
import sys

MODEL_ID = "mskov/whisper_miso"
DATASET_ID = "mskov/miso_test"
SAMPLING_RATE = 16000

# Installed at start-up, before any third-party import, in this order.
PIP_REQUIREMENTS = (
    "transformers==4.27.0",
    "numpy==1.23",
    "jiwer",
    "datasets[audio]",
)

# NOTE(review): the schema of the dataset is not visible from this script;
# these are common transcript column names, tried in order. Set the
# ``transcript_column`` environment variable if none of them applies.
TRANSCRIPT_COLUMN_CANDIDATES = ("sentence", "text", "transcription", "transcript")


def install_dependencies(requirements=PIP_REQUIREMENTS):
    """Install each requirement with the pip of the running interpreter.

    A failing install is reported on stderr but does not abort the script,
    because the package may already be available in the environment.
    """
    for requirement in requirements:
        # ``sys.executable -m pip`` targets the interpreter that will do the
        # imports; an argument list avoids going through a shell.
        result = subprocess.run(
            [sys.executable, "-m", "pip", "install", requirement],
            check=False,
        )
        if result.returncode != 0:
            print(
                f"warning: 'pip install {requirement}' exited with "
                f"status {result.returncode}",
                file=sys.stderr,
            )


def find_transcript_column(column_names, preferred=None):
    """Return the name of the column holding the reference transcripts.

    Args:
        column_names: iterable of the dataset's column names.
        preferred: explicit column name; when given it must exist.

    Returns:
        ``preferred`` if given, otherwise the first entry of
        ``TRANSCRIPT_COLUMN_CANDIDATES`` present in ``column_names``.

    Raises:
        KeyError: if ``preferred`` is missing or no candidate is present.
    """
    available = list(column_names)
    if preferred is not None:
        if preferred in available:
            return preferred
        raise KeyError(
            f"transcript column {preferred!r} not found; "
            f"available columns: {available}"
        )
    for candidate in TRANSCRIPT_COLUMN_CANDIDATES:
        if candidate in available:
            return candidate
    raise KeyError(
        "no transcript column found; tried "
        f"{list(TRANSCRIPT_COLUMN_CANDIDATES)}, available columns: {available}"
    )


def transcribe(model, feature_extractor, tokenizer, waveform,
               sampling_rate=SAMPLING_RATE):
    """Transcribe one waveform and return the decoded text.

    Args:
        model: a ``WhisperForConditionalGeneration`` instance.
        feature_extractor: feature extractor matching ``model``.
        tokenizer: tokenizer used to decode the generated token ids.
        waveform: 1-D array of audio samples.
        sampling_rate: sampling rate of ``waveform`` in Hz.
    """
    import torch

    features = feature_extractor(
        waveform, sampling_rate=sampling_rate, return_tensors="pt"
    )
    with torch.no_grad():
        # Whisper consumes audio features, not token ids; ``generate`` runs
        # the decoder autoregressively and returns the predicted token ids.
        predicted_ids = model.generate(features.input_features)
    return tokenizer.batch_decode(predicted_ids, skip_special_tokens=True)[0]


def main():
    """Install dependencies, transcribe the test split and print its WER."""
    install_dependencies()

    # Imported only now so that the packages installed above are the ones
    # that get loaded.
    from datasets import Audio, disable_caching, load_dataset
    from jiwer import wer
    from transformers import (
        AutoFeatureExtractor,
        WhisperForConditionalGeneration,
        WhisperTokenizer,
    )

    disable_caching()
    huggingface_token = os.environ["huggingface_token"]

    # The conditional-generation head is required to produce token ids; the
    # bare ``WhisperModel`` only returns hidden states.
    model = WhisperForConditionalGeneration.from_pretrained(
        MODEL_ID, use_auth_token=huggingface_token
    )
    feature_extractor = AutoFeatureExtractor.from_pretrained(
        MODEL_ID, use_auth_token=huggingface_token
    )
    miso_tokenizer = WhisperTokenizer.from_pretrained(
        MODEL_ID, use_auth_token=huggingface_token
    )
    model.eval()

    dataset = load_dataset(DATASET_ID, split="test").cast_column(
        "audio", Audio(sampling_rate=SAMPLING_RATE)
    )
    if len(dataset) == 0:
        raise SystemExit(f"{DATASET_ID} has an empty test split; nothing to evaluate")

    # Debug output for the first sample; the audio is decoded once and reused.
    first_audio = dataset[0]["audio"]
    print(dataset, "and at 0[audio][array] ", first_audio["array"], type(first_audio["array"]), "and at audio : ", first_audio)
    inputs = feature_extractor(
        first_audio["array"], sampling_rate=SAMPLING_RATE, return_tensors="pt"
    )
    print("inputs ::: ", inputs, "and dataset type for good measure: ", type(dataset))

    transcript_column = find_transcript_column(
        dataset.column_names, os.environ.get("transcript_column") or None
    )

    references = []
    predictions = []
    for sample in dataset:
        references.append(sample[transcript_column])
        predictions.append(
            transcribe(
                model,
                feature_extractor,
                miso_tokenizer,
                sample["audio"]["array"],
                sample["audio"]["sampling_rate"],
            )
        )

    # jiwer expects the reference transcripts first, then the hypotheses.
    wer_score = wer(references, predictions)
    print(f"Word Error Rate (WER): {wer_score}")


if __name__ == "__main__":
    main()