anuragshas commited on
Commit
66ceace
1 Parent(s): f52a598

Update eval.py

Browse files
Files changed (1) hide show
  1. eval.py +10 -2
eval.py CHANGED
@@ -84,8 +84,10 @@ def main(args):
84
  # resample audio
85
  dataset = dataset.cast_column("audio", Audio(sampling_rate=sampling_rate))
86
 
87
- # load eval pipeline
88
- asr = pipeline("automatic-speech-recognition", model=args.model_id, device=0)
 
 
89
 
90
  # map function to decode audio
91
  def map_to_pred(batch):
@@ -148,6 +150,12 @@ if __name__ == "__main__":
148
  action="store_true",
149
  help="If defined, write outputs to log file for analysis.",
150
  )
 
 
 
 
 
 
151
  args = parser.parse_args()
152
 
153
  main(args)
 
84
  # resample audio
85
  dataset = dataset.cast_column("audio", Audio(sampling_rate=sampling_rate))
86
 
87
+ # load eval pipeline
88
+ if args.device is None:
89
+ args.device = 0 if torch.cuda.is_available() else -1
90
+ asr = pipeline("automatic-speech-recognition", model=args.model_id, device=args.device)
91
 
92
  # map function to decode audio
93
  def map_to_pred(batch):
 
150
  action="store_true",
151
  help="If defined, write outputs to log file for analysis.",
152
  )
153
+ parser.add_argument(
154
+ "--device",
155
+ type=int,
156
+ default=None,
157
+ help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.",
158
+ )
159
  args = parser.parse_args()
160
 
161
  main(args)