hf-test committed
Commit 21443b2
1 Parent(s): 6db1d2d

add eval script

Files changed (2):
  1. eval.py +13 -7
  2. preprocessor_config.json +2 -1
eval.py CHANGED
@@ -2,6 +2,7 @@
 from datasets import load_dataset, load_metric, Audio
 from transformers import AutoModelForCTC, AutoProcessor, Wav2Vec2Processor
 import torch
+import re
 
 lang = "sv-SE"
 model_id = "./xls-r-300m-sv"
@@ -11,12 +12,13 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 dataset = load_dataset("mozilla-foundation/common_voice_7_0", lang, split="test", use_auth_token=True)
 wer = load_metric("wer")
 
-dataset = dataset.select(range(100))
-
 dataset = dataset.cast_column("audio", Audio(sampling_rate=16_000))
 
 model = AutoModelForCTC.from_pretrained(model_id).to(device)
-processor = Wav2Vec2Processor.from_pretrained(model_id)
+processor = AutoProcessor.from_pretrained(model_id)
+
+
+chars_to_ignore_regex = '[,?.!\-\;\:\"“%‘”�—’…–]' # change to the ignored characters of your fine-tuned model
 
 
 def map_to_pred(batch):
@@ -25,15 +27,19 @@ def map_to_pred(batch):
     with torch.no_grad():
         logits = model(input_values.to(device)).logits
 
-    predicted_ids = torch.argmax(logits, dim=-1)
-    transcription = processor.batch_decode(predicted_ids)[0]
+    if processor.__class__.__name__ == "Wav2Vec2Processor":
+        predicted_ids = torch.argmax(logits, dim=-1)
+        transcription = processor.batch_decode(predicted_ids)[0]
+    else:
+        transcription = processor.batch_decode(logits.cpu().numpy()).text[0]
+
     batch["transcription"] = transcription
+    batch["text"] = re.sub(chars_to_ignore_regex, "", batch["sentence"].lower())
     return batch
 
 
 result = dataset.map(map_to_pred, remove_columns=["audio"])
 
-import ipdb; ipdb.set_trace()
-wer_result = wer.compute(references=result["sentence"], predictions=result["transcription"])
+wer_result = wer.compute(references=result["text"], predictions=result["transcription"])
 
 print("WER", wer_result)
preprocessor_config.json CHANGED
@@ -6,5 +6,6 @@
   "padding_side": "right",
   "padding_value": 0,
   "return_attention_mask": true,
-  "sampling_rate": 16000
+  "sampling_rate": 16000,
+  "processor_class": "Wav2Vec2ProcessorWithLM"
 }
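The added processor_class field is what makes the AutoProcessor call in eval.py resolve to the LM-boosted processor rather than the plain one. A minimal illustration, assuming the kenlm/pyctcdecode dependencies of Wav2Vec2ProcessorWithLM are installed:

from transformers import AutoProcessor

# AutoProcessor reads "processor_class" from preprocessor_config.json and
# instantiates Wav2Vec2ProcessorWithLM, so eval.py's class-name check
# takes the language-model decoding branch.
processor = AutoProcessor.from_pretrained("./xls-r-300m-sv")
print(processor.__class__.__name__)  # expected: Wav2Vec2ProcessorWithLM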