pascal lim committed on
Commit
9e32bde
1 Parent(s): da15453

update eval script with lm

Browse files
Files changed (1) hide show
  1. eval_lm.py +24 -25
eval_lm.py CHANGED
@@ -4,7 +4,7 @@ import re
4
  from typing import Dict
5
 
6
  from datasets import Audio, Dataset, load_dataset, load_metric
7
-
8
  from transformers import AutoFeatureExtractor, pipeline, Wav2Vec2ProcessorWithLM
9
 
10
 
@@ -62,39 +62,38 @@ def normalize_text(text: str) -> str:
62
 
63
  return text
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  def main(args):
67
  # load dataset
68
  dataset = load_dataset(args.dataset, args.config, split=args.split, use_auth_token=True)
 
 
69
 
70
  # for testing: only process the first two examples as a test
71
  # dataset = dataset.select(range(10))
72
 
73
  # load processor
74
- processor = Wav2Vec2ProcessorWithLM.from_pretrained("Plim/")
75
 
76
- model = Wav2Vec2ForCTC.from_pretrained(model_id)
77
- feature_extractor = AutoFeatureExtractor.from_pretrained(args.model_id)
78
- sampling_rate = feature_extractor.sampling_rate
79
-
80
- # resample audio
81
- dataset = dataset.cast_column("audio", Audio(sampling_rate=sampling_rate))
82
-
83
- # load eval pipeline
84
- asr = pipeline("automatic-speech-recognition", model=args.model_id)
85
-
86
- # map function to decode audio
87
- def map_to_pred(batch):
88
- prediction = asr(
89
- batch["audio"]["array"], chunk_length_s=args.chunk_length_s, stride_length_s=args.stride_length_s
90
- )
91
-
92
- batch["prediction"] = prediction["text"]
93
- batch["target"] = normalize_text(batch["sentence"])
94
- return batch
95
 
96
  # run inference on all examples
97
- result = dataset.map(map_to_pred, remove_columns=dataset.column_names)
98
 
99
  # compute and log_results
100
  # do not change function below
@@ -104,9 +103,9 @@ def main(args):
104
  if __name__ == "__main__":
105
  parser = argparse.ArgumentParser()
106
 
107
- parser.add_argument(
108
- "--model_id", type=str, required=True, help="Model identifier. Should be loadable with 🤗 Transformers"
109
- )
110
  parser.add_argument(
111
  "--dataset",
112
  type=str,
4
  from typing import Dict
5
 
6
  from datasets import Audio, Dataset, load_dataset, load_metric
7
+ import torch
8
  from transformers import AutoFeatureExtractor, pipeline, Wav2Vec2ProcessorWithLM
9
 
10
 
62
 
63
  return text
64
 
65
def evaluate_with_lm(batch):
    """Transcribe one dataset example with the LM-boosted CTC decoder.

    Adds two keys to `batch`:
      - "prediction": the decoded transcription (a single string)
      - "target": the normalized reference sentence

    NOTE(review): relies on `processor` (Wav2Vec2ProcessorWithLM) and
    `model` (Wav2Vec2ForCTC moved to CUDA) being visible from here —
    in this script they are assigned inside main(), so confirm they are
    actually in scope when `dataset.map` invokes this function.
    """
    inputs = processor(batch["audio"]["array"], sampling_rate=16_000, return_tensors="pt", padding=True)

    # Inference only — skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**inputs.to('cuda')).logits
    int_result = processor.batch_decode(logits.cpu().numpy())

    # batch_decode returns a list of transcriptions (one per sequence in
    # the batch); this map callback handles a single example, so store the
    # single string. The original stored the whole list, which leaves a
    # one-element list in the "prediction" column and breaks downstream
    # metric computation over that column.
    batch["prediction"] = int_result.text[0]
    batch["target"] = normalize_text(batch["sentence"])

    # Release the decoder output and cached GPU memory between examples
    # to keep peak memory down during the full-dataset map.
    del int_result
    torch.cuda.empty_cache()

    return batch
79
 
80
  def main(args):
81
  # load dataset
82
  dataset = load_dataset(args.dataset, args.config, split=args.split, use_auth_token=True)
83
+ # resample audio
84
+ dataset = dataset.cast_column("audio", Audio(sampling_rate=16_000))
85
 
86
  # for testing: only process the first two examples as a test
87
  # dataset = dataset.select(range(10))
88
 
89
  # load processor
90
+ processor = Wav2Vec2ProcessorWithLM.from_pretrained("./")
91
 
92
+ model = Wav2Vec2ForCTC.from_pretrained("./")
93
+ model.to('cuda')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
  # run inference on all examples
96
+ result = dataset.map(evaluate_with_lm, remove_columns=dataset.column_names)
97
 
98
  # compute and log_results
99
  # do not change function below
103
  if __name__ == "__main__":
104
  parser = argparse.ArgumentParser()
105
 
106
+ # parser.add_argument(
107
+ # "--model_id", type=str, required=True, help="Model identifier. Should be loadable with 🤗 Transformers"
108
+ # )
109
  parser.add_argument(
110
  "--dataset",
111
  type=str,