# codesnippets/create_confidence_scores.py
# (Hub upload by patrickvonplaten, revision da69cd8)
#!/usr/bin/env python3
from transformers import AutoModelForCTC, AutoProcessor
from datasets import load_dataset
import datasets
import torch
import sys
model_id = sys.argv[1]
model = AutoModelForCTC.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)
num_samples = 4
do_streaming = True
if do_streaming:
dataset = load_dataset("common_voice", "en", split="test", streaming=True)
dataset = dataset.cast_column("audio", datasets.Audio(sampling_rate=16_000))
# iterate over dataset
dataset_iter = iter(dataset)
samples = [next(dataset_iter) for _ in range(num_samples)]
audio_samples = [s["audio"]["array"] for s in samples]
sampling_rate = set([s["audio"]["sampling_rate"] for s in samples]).pop()
text_samples = [s["sentence"] for s in samples]
else:
dataset = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
samples = dataset[:4]
audio_samples = [s["array"] for s in samples["audio"]]
sampling_rate = set([s["sampling_rate"] for s in samples["audio"]]).pop()
text_samples = samples["text"]
inputs = processor(audio_samples, return_tensors="pt", sampling_rate=sampling_rate, padding=True)
with torch.no_grad():
logits = model(inputs.input_values, attention_mask=inputs.attention_mask).logits
scores = torch.nn.functional.softmax(logits, dim=-1)
pred_ids = torch.argmax(logits, dim=-1)
pred_scores = scores.gather(1, pred_ids.unsqueeze(-1))[:, :, 0]
output = processor.batch_decode(pred_ids, output_word_offsets=True)
# add confidence
def confidence_score(word_dict, index):
probs = pred_scores[index, word_dict["start_offset"]: word_dict["end_offset"]]
return round(torch.mean(probs).item(), 4)
for i in range(num_samples):
print(20 * "=" + f"Output {i}" + 20 * "=")
print(text_samples[i])
print({d["word"]: confidence_score(d, i) for d in output.word_offsets[i]})
print("\n")