anuragshas commited on
Commit
3a559aa
1 Parent(s): 0b33ed3

Update eval.py

Browse files
Files changed (1) hide show
  1. eval.py +12 -1
eval.py CHANGED
@@ -4,6 +4,7 @@ import re
4
  import unicodedata
5
  from typing import Dict
6
 
 
7
  from datasets import Audio, Dataset, load_dataset, load_metric
8
 
9
  from transformers import AutoFeatureExtractor, pipeline
@@ -90,7 +91,11 @@ def main(args):
90
  dataset = dataset.cast_column("audio", Audio(sampling_rate=sampling_rate))
91
 
92
  # load eval pipeline
93
- asr = pipeline("automatic-speech-recognition", model=args.model_id, device=0)
 
 
 
 
94
 
95
  # map function to decode audio
96
  def map_to_pred(batch):
@@ -153,6 +158,12 @@ if __name__ == "__main__":
153
  action="store_true",
154
  help="If defined, write outputs to log file for analysis.",
155
  )
 
 
 
 
 
 
156
  args = parser.parse_args()
157
 
158
  main(args)
 
4
  import unicodedata
5
  from typing import Dict
6
 
7
+ import torch
8
  from datasets import Audio, Dataset, load_dataset, load_metric
9
 
10
  from transformers import AutoFeatureExtractor, pipeline
 
91
  dataset = dataset.cast_column("audio", Audio(sampling_rate=sampling_rate))
92
 
93
  # load eval pipeline
94
+ if args.device is None:
95
+ args.device = 0 if torch.cuda.is_available() else -1
96
+ asr = pipeline(
97
+ "automatic-speech-recognition", model=args.model_id, device=args.device
98
+ )
99
 
100
  # map function to decode audio
101
  def map_to_pred(batch):
 
158
  action="store_true",
159
  help="If defined, write outputs to log file for analysis.",
160
  )
161
+ parser.add_argument(
162
+ "--device",
163
+ type=int,
164
+ default=None,
165
+ help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.",
166
+ )
167
  args = parser.parse_args()
168
 
169
  main(args)