Harveenchadha commited on
Commit
8594e37
1 Parent(s): 249704b

Update evaluations/eval.py

Browse files
Files changed (1) hide show
  1. evaluations/eval.py +9 -1
evaluations/eval.py CHANGED
@@ -93,7 +93,9 @@ def main(args):
93
  dataset = dataset.cast_column("audio", Audio(sampling_rate=sampling_rate))
94
 
95
  # load eval pipeline
96
- asr = pipeline("automatic-speech-recognition", model=args.model_id)
 
 
97
 
98
  # map function to decode audio
99
  def map_to_pred(batch):
@@ -140,6 +142,12 @@ if __name__ == "__main__":
140
  parser.add_argument(
141
  "--log_outputs", action="store_true", help="If defined, write outputs to log file for analysis."
142
  )
 
 
 
 
 
 
143
  args = parser.parse_args()
144
 
145
  main(args)
 
93
  dataset = dataset.cast_column("audio", Audio(sampling_rate=sampling_rate))
94
 
95
  # load eval pipeline
96
+ if args.device is None:
97
+ args.device = 0 if torch.cuda.is_available() else -1
98
+ asr = pipeline("automatic-speech-recognition", model=args.model_id, device=args.device)
99
 
100
  # map function to decode audio
101
  def map_to_pred(batch):
 
142
  parser.add_argument(
143
  "--log_outputs", action="store_true", help="If defined, write outputs to log file for analysis."
144
  )
145
+ parser.add_argument(
146
+ "--device",
147
+ type=int,
148
+ default=None,
149
+ help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.",
150
+ )
151
  args = parser.parse_args()
152
 
153
  main(args)