infinitejoy committed on
Commit
f49535d
1 Parent(s): 08fd86d

Upload eval.py

Browse files
Files changed (1) hide show
  1. eval.py +11 -13
eval.py CHANGED
@@ -3,6 +3,7 @@ import argparse
3
  import re
4
  from typing import Dict
5
 
 
6
  from datasets import Audio, Dataset, load_dataset, load_metric
7
 
8
  from transformers import AutoFeatureExtractor, pipeline
@@ -51,18 +52,7 @@ def normalize_text(text: str) -> str:
51
 
52
  chars_to_ignore_regex = '[,?.!\-\;\:"“%‘”�—’…–]' # noqa: W605 IMPORTANT: this should correspond to the chars that were ignored during training
53
 
54
- text = re.sub(chars_to_ignore_regex, "", text.lower()) \
55
- .replace("\\\\punkt", "") \
56
- .replace("\\\\komma", "") \
57
- .replace("è", "e") \
58
- .replace("é", "e") \
59
- .replace("î", "i") \
60
- .replace("ü", "u") \
61
- .replace("ÿ", "y") \
62
- .replace("ô", "o") \
63
- .replace("\\", "") \
64
- .replace("/", "") \
65
- .replace("|", "")
66
 
67
  # In addition, we can normalize the target text, e.g. removing new lines characters etc...
68
  # note that order is important here!
@@ -89,7 +79,9 @@ def main(args):
89
  dataset = dataset.cast_column("audio", Audio(sampling_rate=sampling_rate))
90
 
91
  # load eval pipeline
92
- asr = pipeline("automatic-speech-recognition", model=args.model_id)
 
 
93
 
94
  # map function to decode audio
95
  def map_to_pred(batch):
@@ -134,6 +126,12 @@ if __name__ == "__main__":
134
  parser.add_argument(
135
  "--log_outputs", action="store_true", help="If defined, write outputs to log file for analysis."
136
  )
 
 
 
 
 
 
137
  args = parser.parse_args()
138
 
139
  main(args)
 
3
  import re
4
  from typing import Dict
5
 
6
+ import torch
7
  from datasets import Audio, Dataset, load_dataset, load_metric
8
 
9
  from transformers import AutoFeatureExtractor, pipeline
 
52
 
53
  chars_to_ignore_regex = '[,?.!\-\;\:"“%‘”�—’…–]' # noqa: W605 IMPORTANT: this should correspond to the chars that were ignored during training
54
 
55
+ text = re.sub(chars_to_ignore_regex, "", text.lower())
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  # In addition, we can normalize the target text, e.g. removing new lines characters etc...
58
  # note that order is important here!
 
79
  dataset = dataset.cast_column("audio", Audio(sampling_rate=sampling_rate))
80
 
81
  # load eval pipeline
82
+ if args.device is None:
83
+ args.device = 0 if torch.cuda.is_available() else -1
84
+ asr = pipeline("automatic-speech-recognition", model=args.model_id, device=args.device)
85
 
86
  # map function to decode audio
87
  def map_to_pred(batch):
 
126
  parser.add_argument(
127
  "--log_outputs", action="store_true", help="If defined, write outputs to log file for analysis."
128
  )
129
+ parser.add_argument(
130
+ "--device",
131
+ type=int,
132
+ default=None,
133
+ help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.",
134
+ )
135
  args = parser.parse_args()
136
 
137
  main(args)