anuragshas commited on
Commit
5078d07
1 Parent(s): 66ceace

Update eval.py

Browse files
Files changed (1) hide show
  1. eval.py +12 -9
eval.py CHANGED
@@ -4,6 +4,7 @@ import re
4
  import unicodedata
5
  from typing import Dict
6
 
 
7
  from datasets import Audio, Dataset, load_dataset, load_metric
8
 
9
  from transformers import AutoFeatureExtractor, pipeline
@@ -84,10 +85,12 @@ def main(args):
84
  # resample audio
85
  dataset = dataset.cast_column("audio", Audio(sampling_rate=sampling_rate))
86
 
87
- # load eval pipeline
88
- if args.device is None:
89
- args.device = 0 if torch.cuda.is_available() else -1
90
- asr = pipeline("automatic-speech-recognition", model=args.model_id, device=args.device)
 
 
91
 
92
  # map function to decode audio
93
  def map_to_pred(batch):
@@ -150,11 +153,11 @@ if __name__ == "__main__":
150
  action="store_true",
151
  help="If defined, write outputs to log file for analysis.",
152
  )
153
- parser.add_argument(
154
- "--device",
155
- type=int,
156
- default=None,
157
- help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.",
158
  )
159
  args = parser.parse_args()
160
 
 
4
  import unicodedata
5
  from typing import Dict
6
 
7
+ import torch
8
  from datasets import Audio, Dataset, load_dataset, load_metric
9
 
10
  from transformers import AutoFeatureExtractor, pipeline
 
85
  # resample audio
86
  dataset = dataset.cast_column("audio", Audio(sampling_rate=sampling_rate))
87
 
88
+ # load eval pipeline
89
+ if args.device is None:
90
+ args.device = 0 if torch.cuda.is_available() else -1
91
+ asr = pipeline(
92
+ "automatic-speech-recognition", model=args.model_id, device=args.device
93
+ )
94
 
95
  # map function to decode audio
96
  def map_to_pred(batch):
 
153
  action="store_true",
154
  help="If defined, write outputs to log file for analysis.",
155
  )
156
+ parser.add_argument(
157
+ "--device",
158
+ type=int,
159
+ default=None,
160
+ help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.",
161
  )
162
  args = parser.parse_args()
163