patrickvonplaten committed
Commit a246192
1 Parent(s): 265ea69
Files changed (1)
  1. run_ctc_model.py +40 -12
run_ctc_model.py CHANGED
@@ -3,30 +3,58 @@ import sys
 import torch
 
 from transformers import AutoModelForCTC, AutoProcessor
-from datasets import load_dataset
+from datasets import load_dataset, load_metric
 import torchaudio.functional as F
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 model_id = sys.argv[1]
 lang = sys.argv[2]
+lang_phoneme = sys.argv[3]
+num_samples = int(sys.argv[4])
+
+model = AutoModelForCTC.from_pretrained(model_id).to(device)
+processor = AutoProcessor.from_pretrained(model_id)
 
 ds = load_dataset("common_voice", lang, split="test", streaming=True)
+sample_iter = iter(ds)
 
-sample = next(iter(ds))
+wer = load_metric("wer")
+cer = load_metric("cer")
 
-resampled_audio = F.resample(torch.tensor(sample["audio"]["array"]), 48_000, 16_000).numpy()
+targets_ids = []
+predictions_ids = []
+for i in range(num_samples):
+    sample = next(sample_iter)
+    resampled_audio = F.resample(torch.tensor(sample["audio"]["array"]), 48_000, 16_000).numpy()
+    input_values = processor(resampled_audio, return_tensors="pt").input_values
 
-model = AutoModelForCTC.from_pretrained(model_id).to(device)
-processor = AutoProcessor.from_pretrained(model_id)
+    with torch.no_grad():
+        logits = model(input_values.to(device)).logits
+
+    prediction_ids = torch.argmax(logits, dim=-1)
+    transcription = processor.batch_decode(prediction_ids)
+
+    print(f"Correct: {sample['sentence']}")
+    print(f"Predict: {transcription}")
+    print(20 * '-')
+
+    predictions_ids.append(prediction_ids[0].tolist())
+
+    kwargs = {}
+    if len(lang_phoneme) > 0:
+        kwargs["phonemizer_lang"] = lang_phoneme
+
+    targets_ids.append(processor.tokenizer(sample["sentence"], **kwargs).input_ids)
 
-input_values = processor(resampled_audio, return_tensors="pt").input_values
+print("Compute metrics.....")
 
-with torch.no_grad():
-    logits = model(input_values.to(device)).logits
+# import ipdb; ipdb.set_trace()  # leftover debug breakpoint, disabled so the script runs unattended
+transcriptions = processor.batch_decode(predictions_ids)
+targets_str = processor.batch_decode(targets_ids, group_tokens=False)
 
-prediction_ids = torch.argmax(logits, dim=-1)
-transcription = processor.batch_decode(prediction_ids)
+wer = wer.compute(predictions=transcriptions, references=targets_str)
+cer = cer.compute(predictions=transcriptions, references=targets_str)
 
-print(f"Correct: {sample['sentence']}")
-print(f"Predict: {transcription}")
+print("wer", wer)
+print("cer", cer)
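With the two new positional arguments, the script is invoked as: python run_ctc_model.py <model_id> <lang> <phonemizer_lang> <num_samples>. Passing an empty string as the third argument skips the phonemizer_lang kwarg, which only applies to phoneme tokenizers; an invocation might look like: python run_ctc_model.py facebook/wav2vec2-xlsr-53-espeak-cv-ft tr tr 20 (the model id and values here are illustrative, not taken from the commit). The F.resample call converts Common Voice's 48 kHz audio to the 16 kHz rate the model expects, and the references are decoded with group_tokens=False because they are genuine token sequences rather than frame-level CTC output, so consecutive duplicate tokens must not be collapsed.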
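For reference, both metrics loaded via load_metric are edit-distance based: WER counts the word-level substitutions, insertions, and deletions needed to turn a prediction into its reference, normalized by the number of reference words, and CER does the same at character level. A minimal self-contained sketch of that computation (an illustrative reimplementation, not the datasets code):

# Illustrative WER/CER: Levenshtein distance over tokens, normalized by
# reference length. A sketch of what load_metric("wer") / load_metric("cer")
# report, not the library implementation.

def edit_distance(ref, hyp):
    # Single-row dynamic-programming Levenshtein distance between two sequences.
    d = list(range(len(hyp) + 1))       # distances against the empty reference
    for i, r in enumerate(ref, start=1):
        prev, d[0] = d[0], i            # prev tracks the diagonal d[i-1][j-1]
        for j, h in enumerate(hyp, start=1):
            cur = min(d[j] + 1,         # delete a reference token
                      d[j - 1] + 1,     # insert a hypothesis token
                      prev + (r != h))  # substitute, free on a match
            prev, d[j] = d[j], cur
    return d[-1]

def wer(reference, prediction):
    ref_words = reference.split()
    return edit_distance(ref_words, prediction.split()) / len(ref_words)

def cer(reference, prediction):
    return edit_distance(list(reference), list(prediction)) / len(reference)

print(wer("hello world", "hello word"))  # 0.5: one substituted word out of two

Note that the library versions compute a corpus-level score, aggregating edits across the full lists of predictions and references before normalizing, rather than averaging per-sentence scores as this sketch would.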