patrickvonplaten committed
Commit da69cd8
1 Parent(s): 685ce0f
Files changed (2)
  1. README.md +4 -0
  2. create_confidence_scores.py +39 -12
README.md ADDED
@@ -0,0 +1,4 @@
+ # Confidence Scoring
+
+ Read https://x-lance.sjtu.edu.cn/papers/zhc00-chen-icassp17.pdf
+ Run create_confidence_scores.py
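A usage note, inferred from the updated script below rather than stated in the README: the new create_confidence_scores.py reads the checkpoint from its first command-line argument (`model_id = sys.argv[1]`), so a run looks like `python create_confidence_scores.py facebook/wav2vec2-base-960h` with any CTC checkpoint from the Hub, and it prints the reference sentence together with a word-to-confidence dictionary for each of the four decoded samples.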
create_confidence_scores.py CHANGED
@@ -1,30 +1,57 @@
  #!/usr/bin/env python3
- from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
+ from transformers import AutoModelForCTC, AutoProcessor
  from datasets import load_dataset
  import datasets
  import torch
+ import sys

- model = Wav2Vec2ForCTC.from_pretrained("facebook/data2vec-audio-base-10m")
- processor = Wav2Vec2Processor.from_pretrained("facebook/data2vec-audio-base-10m")
+ model_id = sys.argv[1]

- minds14 = load_dataset("PolyAI/minds14", "en-US", split="train")
- minds14 = minds14.cast_column("audio", datasets.Audio(sampling_rate=16_000))
+ model = AutoModelForCTC.from_pretrained(model_id)
+ processor = AutoProcessor.from_pretrained(model_id)

- input_values = processor(minds14[0]["audio"]["array"], return_tensors="pt", sampling_rate=minds14[0]["audio"]["sampling_rate"]).input_values
+ num_samples = 4
+
+ do_streaming = True
+
+ if do_streaming:
+     dataset = load_dataset("common_voice", "en", split="test", streaming=True)
+     dataset = dataset.cast_column("audio", datasets.Audio(sampling_rate=16_000))
+
+     # iterate over dataset
+     dataset_iter = iter(dataset)
+     samples = [next(dataset_iter) for _ in range(num_samples)]
+
+     audio_samples = [s["audio"]["array"] for s in samples]
+     sampling_rate = set([s["audio"]["sampling_rate"] for s in samples]).pop()
+     text_samples = [s["sentence"] for s in samples]
+
+ else:
+     dataset = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
+     samples = dataset[:4]
+     audio_samples = [s["array"] for s in samples["audio"]]
+     sampling_rate = set([s["sampling_rate"] for s in samples["audio"]]).pop()
+     text_samples = samples["text"]
+
+ inputs = processor(audio_samples, return_tensors="pt", sampling_rate=sampling_rate, padding=True)

  with torch.no_grad():
-     logits = model(input_values).logits
+     logits = model(inputs.input_values, attention_mask=inputs.attention_mask).logits
      scores = torch.nn.functional.softmax(logits, dim=-1)
      pred_ids = torch.argmax(logits, dim=-1)
      pred_scores = scores.gather(1, pred_ids.unsqueeze(-1))[:, :, 0]

  output = processor.batch_decode(pred_ids, output_word_offsets=True)

+
  # add confidence
- def confidence_score(word_dict):
-     probs = pred_scores[0, word_dict["start_offset"]: word_dict["end_offset"]]
-     return torch.mean(probs)
+ def confidence_score(word_dict, index):
+     probs = pred_scores[index, word_dict["start_offset"]: word_dict["end_offset"]]
+     return round(torch.mean(probs).item(), 4)

- output["confidence_scores"] = {d["word"]: confidence_score(d) for d in output.word_offsets[0]}

- print(output["confidence_scores"])
+ for i in range(num_samples):
+     print(20 * "=" + f"Output {i}" + 20 * "=")
+     print(text_samples[i])
+     print({d["word"]: confidence_score(d, i) for d in output.word_offsets[i]})
+     print("\n")