versae commited on
Commit
095b715
1 Parent(s): f625488

Update eval.py

Browse files
Files changed (1) hide show
  1. eval.py +2 -8
eval.py CHANGED
@@ -6,7 +6,7 @@ from typing import Dict
6
  import torch
7
  from datasets import Audio, Dataset, load_dataset, load_metric
8
 
9
- from transformers import AutoFeatureExtractor, AutoModel, AutoTokenizer, pipeline
10
 
11
 
12
  def log_results(result: Dataset, args: Dict[str, str]):
@@ -81,7 +81,6 @@ def normalize_text(text: str) -> str:
81
  def main(args):
82
  # load dataset
83
  dataset = load_dataset(args.dataset, args.config, split=args.split, use_auth_token=True)
84
- #dataset = load_dataset(args.dataset, args.config, split=args.split, use_auth_token=True).filter(lambda entry: re.search("nb-nn", entry["sentence_language_code"], flags=re.IGNORECASE))
85
 
86
  # for testing: only process the first two examples as a test
87
  # dataset = dataset.select(range(10))
@@ -96,12 +95,7 @@ def main(args):
96
  # load eval pipeline
97
  if args.device is None:
98
  args.device = 0 if torch.cuda.is_available() else -1
99
- asr = pipeline("automatic-speech-recognition",
100
- model=AutoModel.from_pretrained(args.model_id),
101
- tokenizer=AutoTokenizer.from_pretrained(args.model_id),
102
- feature_extractor=feature_extractor,
103
- device=args.device
104
- )
105
 
106
  # map function to decode audio
107
  def map_to_pred(batch):
 
6
  import torch
7
  from datasets import Audio, Dataset, load_dataset, load_metric
8
 
9
+ from transformers import AutoFeatureExtractor, pipeline
10
 
11
 
12
  def log_results(result: Dataset, args: Dict[str, str]):
 
81
  def main(args):
82
  # load dataset
83
  dataset = load_dataset(args.dataset, args.config, split=args.split, use_auth_token=True)
 
84
 
85
  # for testing: only process the first two examples as a test
86
  # dataset = dataset.select(range(10))
 
95
  # load eval pipeline
96
  if args.device is None:
97
  args.device = 0 if torch.cuda.is_available() else -1
98
+ asr = pipeline("automatic-speech-recognition", model=args.model_id, device=args.device)
 
 
 
 
 
99
 
100
  # map function to decode audio
101
  def map_to_pred(batch):