comodoro committed on
Commit
ca1377c
1 Parent(s): ad6ba90

Model with more data

Browse files
Files changed (4) hide show
  1. eval.py +20 -5
  2. language_model/attrs.json +1 -1
  3. train.ipynb +0 -0
  4. vocab.json +1 -1
eval.py CHANGED
@@ -1,6 +1,7 @@
1
  #!/usr/bin/env python3
2
  from datasets import load_dataset, load_metric, Audio, Dataset
3
- from transformers import pipeline, AutoFeatureExtractor
 
4
  import re
5
  import argparse
6
  import unicodedata
@@ -106,18 +107,29 @@ def main(args):
106
  dataset = load_dataset(args.dataset, args.config, split=args.split, use_auth_token=True)
107
 
108
  # for testing: only process the first two examples as a test
109
- # dataset = dataset.select(range(10))
 
110
 
111
- # load processor
112
  feature_extractor = AutoFeatureExtractor.from_pretrained(args.model_id)
 
113
  sampling_rate = feature_extractor.sampling_rate
114
 
115
  # resample audio
116
  dataset = dataset.cast_column("audio", Audio(sampling_rate=sampling_rate))
 
 
 
 
 
 
 
117
 
118
- # load eval pipeline
119
- asr = pipeline("automatic-speech-recognition", model=args.model_id)
120
 
 
 
 
121
  # map function to decode audio
122
  def map_to_pred(batch):
123
  prediction = asr(batch["audio"]["array"], chunk_length_s=args.chunk_length_s, stride_length_s=args.stride_length_s)
@@ -158,6 +170,9 @@ if __name__ == "__main__":
158
  parser.add_argument(
159
  "--log_outputs", action='store_true', help="If defined, write outputs to log file for analysis."
160
  )
 
 
 
161
  args = parser.parse_args()
162
 
163
  main(args)
1
  #!/usr/bin/env python3
2
  from datasets import load_dataset, load_metric, Audio, Dataset
3
+ from transformers import pipeline, AutoFeatureExtractor, AutoTokenizer, Wav2Vec2ForCTC
4
+ import os
5
  import re
6
  import argparse
7
  import unicodedata
107
  dataset = load_dataset(args.dataset, args.config, split=args.split, use_auth_token=True)
108
 
109
  # for testing: optionally process only a small subset of the dataset
110
+ if args.limit:
111
+ dataset = dataset.select(range(args.limit))
112
 
 
113
  feature_extractor = AutoFeatureExtractor.from_pretrained(args.model_id)
114
+ # load processor
115
  sampling_rate = feature_extractor.sampling_rate
116
 
117
  # resample audio
118
  dataset = dataset.cast_column("audio", Audio(sampling_rate=sampling_rate))
119
+
120
+ asr = None
121
+
122
+ if os.path.exists(args.model_id):
123
+ model = Wav2Vec2ForCTC.from_pretrained(args.model_id)
124
+ tokenizer = AutoTokenizer.from_pretrained(args.model_id)
125
+
126
 
127
+ # load eval pipeline
128
+ asr = pipeline("automatic-speech-recognition", model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)
129
 
130
+ else:
131
+ asr = pipeline("automatic-speech-recognition", model=args.model_id)
132
+
133
  # map function to decode audio
134
  def map_to_pred(batch):
135
  prediction = asr(batch["audio"]["array"], chunk_length_s=args.chunk_length_s, stride_length_s=args.stride_length_s)
170
  parser.add_argument(
171
  "--log_outputs", action='store_true', help="If defined, write outputs to log file for analysis."
172
  )
173
+ parser.add_argument(
174
+ "--limit", type=int, help="Not required. If greater than zero, select a subset of this size from the dataset.", default=0
175
+ )
176
  args = parser.parse_args()
177
 
178
  main(args)
language_model/attrs.json CHANGED
@@ -1 +1 @@
1
- {"alpha": 0.5, "beta": 1.5, "unk_score_offset": -10.0, "score_boundary": true}
1
+ {"alpha": 0.9, "beta": 2.5, "unk_score_offset": -10.0, "score_boundary": true}
train.ipynb ADDED
The diff for this file is too large to render. See raw diff
vocab.json CHANGED
@@ -1 +1 @@
1
- {"a": 1, "b": 2, "c": 3, "d": 4, "e": 5, "f": 6, "g": 7, "h": 8, "i": 9, "j": 10, "k": 11, "l": 12, "m": 13, "n": 14, "o": 15, "p": 16, "q": 17, "r": 18, "s": 19, "t": 20, "u": 21, "v": 22, "w": 23, "x": 24, "y": 25, "z": 26, "á": 27, "é": 28, "í": 29, "ó": 30, "ú": 31, "ý": 32, "č": 33, "ď": 34, "ě": 35, "ň": 36, "ř": 37, "š": 38, "ť": 39, "ů": 40, "ž": 41, "|": 0, "[UNK]": 42, "[PAD]": 43}
1
+ {"a": 1, "b": 2, "c": 3, "d": 4, "e": 5, "f": 6, "g": 7, "h": 8, "i": 9, "j": 10, "k": 11, "l": 12, "m": 13, "n": 14, "o": 15, "p": 16, "q": 17, "r": 18, "s": 19, "t": 20, "u": 21, "v": 22, "w": 23, "x": 24, "y": 25, "z": 26, "\u00e1": 27, "\u00e9": 28, "\u00ed": 29, "\u00f3": 30, "\u00fa": 31, "\u00fd": 32, "\u010d": 33, "\u010f": 34, "\u011b": 35, "\u0148": 36, "\u0159": 37, "\u0161": 38, "\u0165": 39, "\u016f": 40, "\u017e": 41, "|": 0, "[UNK]": 42, "[PAD]": 43}