Model with more data
Files changed:
- eval.py (+20, -5)
- language_model/attrs.json (+1, -1)
- train.ipynb (added)
- vocab.json (+1, -1)
eval.py
CHANGED

@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 from datasets import load_dataset, load_metric, Audio, Dataset
-from transformers import pipeline, AutoFeatureExtractor
+from transformers import pipeline, AutoFeatureExtractor, AutoTokenizer, Wav2Vec2ForCTC
+import os
 import re
 import argparse
 import unicodedata
@@ -106,18 +107,29 @@ def main(args):
     dataset = load_dataset(args.dataset, args.config, split=args.split, use_auth_token=True)
 
     # for testing: only process the first two examples as a test
-    # dataset = dataset.select(range(10))
+    if args.limit:
+        dataset = dataset.select(range(args.limit))
 
-    # load processor
     feature_extractor = AutoFeatureExtractor.from_pretrained(args.model_id)
+    # load processor
     sampling_rate = feature_extractor.sampling_rate
 
     # resample audio
     dataset = dataset.cast_column("audio", Audio(sampling_rate=sampling_rate))
 
-    # load eval pipeline
-    asr = pipeline("automatic-speech-recognition", model=args.model_id)
+
+    asr = None
+
+    if os.path.exists(args.model_id):
+        model = Wav2Vec2ForCTC.from_pretrained(args.model_id)
+        tokenizer = AutoTokenizer.from_pretrained(args.model_id)
+
+        # load eval pipeline
+        asr = pipeline("automatic-speech-recognition", model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)
+
+    else:
+        asr = pipeline("automatic-speech-recognition", model=args.model_id)
 
     # map function to decode audio
     def map_to_pred(batch):
         prediction = asr(batch["audio"]["array"], chunk_length_s=args.chunk_length_s, stride_length_s=args.stride_length_s)
@@ -158,6 +170,9 @@ if __name__ == "__main__":
     parser.add_argument(
         "--log_outputs", action='store_true', help="If defined, write outputs to log file for analysis."
     )
+    parser.add_argument(
+        "--limit", type=int, help="Not required. If greater than zero, select a subset of this size from the dataset.", default=0
+    )
     args = parser.parse_args()
 
     main(args)
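With this change, --model_id can point either to a local checkpoint directory (loaded explicitly with Wav2Vec2ForCTC and its tokenizer, reusing the already-loaded feature extractor) or to a model id on the Hub (resolved by the pipeline itself). The new --limit flag defaults to 0, and since the guard is "if args.limit:", omitting it evaluates the full split. A hypothetical invocation for a quick smoke test follows; the checkpoint path and dataset/config values are illustrative, not taken from this commit:

    python eval.py --model_id ./my-local-checkpoint \
        --dataset mozilla-foundation/common_voice_8_0 --config cs --split test \
        --chunk_length_s 5.0 --stride_length_s 1.0 --limit 100
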
language_model/attrs.json
CHANGED

@@ -1 +1 @@
-{"alpha": 0.
+{"alpha": 0.9, "beta": 2.5, "unk_score_offset": -10.0, "score_boundary": true}
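attrs.json stores the beam-search settings of the pyctcdecode decoder that ships with this model: alpha weights the language-model score, beta is the word-insertion bonus, unk_score_offset penalizes unknown tokens, and score_boundary controls whether the LM scores words at boundaries. Wav2Vec2ProcessorWithLM.from_pretrained reconstructs the decoder from the language_model/ directory, so these values take effect automatically at eval time. As a minimal sketch, an equivalent decoder could be built by hand with pyctcdecode; the KenLM file name here is an assumption, not taken from this repo:

    import json
    from pyctcdecode import build_ctcdecoder

    # Order the CTC labels by their ids from vocab.json.
    with open("vocab.json") as f:
        vocab = json.load(f)
    labels = [tok for tok, _ in sorted(vocab.items(), key=lambda kv: kv[1])]

    decoder = build_ctcdecoder(
        labels,
        kenlm_model_path="language_model/5gram.bin",  # assumed file name
        alpha=0.9,               # LM weight, as in attrs.json
        beta=2.5,                # word-insertion bonus
        unk_score_offset=-10.0,  # penalty for unknown tokens
        lm_score_boundary=True,  # pyctcdecode's name for "score_boundary"
    )
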
train.ipynb
ADDED

The diff for this notebook is too large to render.
vocab.json
CHANGED

@@ -1 +1 @@
-{"a": 1, "b": 2, "c": 3, "d": 4, "e": 5, "f": 6, "g": 7, "h": 8, "i": 9, "j": 10, "k": 11, "l": 12, "m": 13, "n": 14, "o": 15, "p": 16, "q": 17, "r": 18, "s": 19, "t": 20, "u": 21, "v": 22, "w": 23, "x": 24, "y": 25, "z": 26, "
+{"a": 1, "b": 2, "c": 3, "d": 4, "e": 5, "f": 6, "g": 7, "h": 8, "i": 9, "j": 10, "k": 11, "l": 12, "m": 13, "n": 14, "o": 15, "p": 16, "q": 17, "r": 18, "s": 19, "t": 20, "u": 21, "v": 22, "w": 23, "x": 24, "y": 25, "z": 26, "\u00e1": 27, "\u00e9": 28, "\u00ed": 29, "\u00f3": 30, "\u00fa": 31, "\u00fd": 32, "\u010d": 33, "\u010f": 34, "\u011b": 35, "\u0148": 36, "\u0159": 37, "\u0161": 38, "\u0165": 39, "\u016f": 40, "\u017e": 41, "|": 0, "[UNK]": 42, "[PAD]": 43}
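The updated vocabulary keeps the 26 lowercase ASCII letters and adds the accented letters of Czech orthography (á, é, í, ó, ú, ý, č, ď, ě, ň, ř, š, ť, ů, ž), with "|" as the word delimiter at id 0 and [UNK]/[PAD] at ids 42 and 43. A minimal sketch of loading this file with the stock CTC tokenizer; the printed ids follow directly from the mapping above:

    from transformers import Wav2Vec2CTCTokenizer

    # Special-token names match the entries in vocab.json.
    tokenizer = Wav2Vec2CTCTokenizer(
        "vocab.json",
        unk_token="[UNK]",
        pad_token="[PAD]",
        word_delimiter_token="|",
    )

    print(tokenizer.convert_tokens_to_ids(["ř", "|", "[PAD]"]))  # [37, 0, 43]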