anuragshas
committed on
Commit
•
5e5e5a8
1
Parent(s):
ef328b9
Upload lm-boosted decoder
Browse files- .gitattributes +1 -0
- .ipynb_checkpoints/run-checkpoint.sh +0 -33
- .ipynb_checkpoints/vocab-checkpoint.json +0 -1
- alphabet.json +1 -0
- eval.py +164 -0
- language_model/4gram.bin +3 -0
- language_model/attrs.json +1 -0
- language_model/unigrams.txt +3 -0
- preprocessor_config.json +1 -0
- special_tokens_map.json +1 -1
- tokenizer_config.json +1 -1
.gitattributes
CHANGED
@@ -25,3 +25,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
25 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
26 |
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
27 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
25 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
26 |
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
27 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
language_model/unigrams.txt filter=lfs diff=lfs merge=lfs -text
|
.ipynb_checkpoints/run-checkpoint.sh
DELETED
@@ -1,33 +0,0 @@
|
|
1 |
-
python run_speech_recognition_ctc.py \
|
2 |
-
--dataset_name="mozilla-foundation/common_voice_8_0" \
|
3 |
-
--model_name_or_path="facebook/wav2vec2-xls-r-300m" \
|
4 |
-
--dataset_config_name="mr" \
|
5 |
-
--output_dir="./" \
|
6 |
-
--overwrite_output_dir \
|
7 |
-
--num_train_epochs="500" \
|
8 |
-
--per_device_train_batch_size="32" \
|
9 |
-
--per_device_eval_batch_size="16" \
|
10 |
-
--learning_rate="7.5e-5" \
|
11 |
-
--warmup_steps="2000" \
|
12 |
-
--length_column_name="input_length" \
|
13 |
-
--evaluation_strategy="steps" \
|
14 |
-
--text_column_name="sentence" \
|
15 |
-
--save_steps="400" \
|
16 |
-
--eval_steps="400" \
|
17 |
-
--logging_steps="100" \
|
18 |
-
--layerdrop="0.0" \
|
19 |
-
--activation_dropout="0.1" \
|
20 |
-
--save_total_limit="1" \
|
21 |
-
--freeze_feature_encoder \
|
22 |
-
--feat_proj_dropout="0.0" \
|
23 |
-
--mask_time_prob="0.75" \
|
24 |
-
--mask_time_length="10" \
|
25 |
-
--mask_feature_prob="0.25" \
|
26 |
-
--mask_feature_length="64" \
|
27 |
-
--chars_to_ignore , ? . ! \- \; \: \" “ % ‘ ” � — ’ … – \' । \॔ \
|
28 |
-
--gradient_checkpointing \
|
29 |
-
--use_auth_token \
|
30 |
-
--fp16 \
|
31 |
-
--group_by_length \
|
32 |
-
--do_train --do_eval \
|
33 |
-
--push_to_hub
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.ipynb_checkpoints/vocab-checkpoint.json
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
{"ँ": 1, "ं": 2, "ः": 3, "अ": 4, "आ": 5, "इ": 6, "ई": 7, "उ": 8, "ऊ": 9, "ऋ": 10, "ए": 11, "ऐ": 12, "ऑ": 13, "ओ": 14, "औ": 15, "क": 16, "ख": 17, "ग": 18, "घ": 19, "च": 20, "छ": 21, "ज": 22, "झ": 23, "ञ": 24, "ट": 25, "ठ": 26, "ड": 27, "ढ": 28, "ण": 29, "त": 30, "थ": 31, "द": 32, "ध": 33, "न": 34, "प": 35, "फ": 36, "ब": 37, "भ": 38, "म": 39, "य": 40, "र": 41, "ऱ": 42, "ल": 43, "ळ": 44, "व": 45, "श": 46, "ष": 47, "स": 48, "ह": 49, "़": 50, "ा": 51, "ि": 52, "ी": 53, "ु": 54, "ू": 55, "ृ": 56, "ॄ": 57, "ॅ": 58, "े": 59, "ै": 60, "ॉ": 61, "ॊ": 62, "ो": 63, "ौ": 64, "्": 65, "|": 0, "[UNK]": 66, "[PAD]": 67}
|
|
|
|
alphabet.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"labels": [" ", "\u0901", "\u0902", "\u0903", "\u0905", "\u0906", "\u0907", "\u0908", "\u0909", "\u090a", "\u090b", "\u090f", "\u0910", "\u0911", "\u0913", "\u0914", "\u0915", "\u0916", "\u0917", "\u0918", "\u091a", "\u091b", "\u091c", "\u091d", "\u091e", "\u091f", "\u0920", "\u0921", "\u0922", "\u0923", "\u0924", "\u0925", "\u0926", "\u0927", "\u0928", "\u092a", "\u092b", "\u092c", "\u092d", "\u092e", "\u092f", "\u0930", "\u0931", "\u0932", "\u0933", "\u0935", "\u0936", "\u0937", "\u0938", "\u0939", "\u093c", "\u093e", "\u093f", "\u0940", "\u0941", "\u0942", "\u0943", "\u0944", "\u0945", "\u0947", "\u0948", "\u0949", "\u094a", "\u094b", "\u094c", "\u094d", "\u2047", "", "<s>", "</s>"], "is_bpe": false}
|
eval.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
import argparse
|
3 |
+
import re
|
4 |
+
import unicodedata
|
5 |
+
from typing import Dict
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from datasets import Audio, Dataset, load_dataset, load_metric
|
9 |
+
|
10 |
+
from transformers import AutoFeatureExtractor, pipeline
|
11 |
+
|
12 |
+
|
13 |
+
def log_results(result: Dataset, args: argparse.Namespace):
    """Compute WER and CER for ``result``, print them and write them to disk.

    The original file marks this routine "DO NOT CHANGE"; the metric
    computation, the printed summary and all output file names are kept
    exactly as they were. Only the file handling was made robust.

    Args:
        result: Dataset providing a "prediction" and a "target" column.
        args: Parsed command-line arguments. The attributes ``dataset``,
            ``config``, ``split`` and ``log_outputs`` are read.
    """
    dataset_id = "_".join(args.dataset.split("/") + [args.config, args.split])

    # load metric
    wer = load_metric("wer")
    cer = load_metric("cer")

    # compute metrics
    predictions = result["prediction"]
    targets = result["target"]
    wer_result = wer.compute(references=targets, predictions=predictions)
    cer_result = cer.compute(references=targets, predictions=predictions)

    # print & log results
    result_str = f"WER: {wer_result}\n" f"CER: {cer_result}"
    print(result_str)

    # Explicit UTF-8: the default locale encoding may be unable to represent
    # the transcripts that are written further below, so use one encoding
    # consistently for every output file.
    with open(f"{dataset_id}_eval_results.txt", "w", encoding="utf-8") as f:
        f.write(result_str)

    # `--log_outputs` is a store_true flag, so it is False (never None) when
    # omitted; test its truthiness so the logs are only written on request.
    if args.log_outputs:
        pred_file = f"log_{dataset_id}_predictions.txt"
        target_file = f"log_{dataset_id}_targets.txt"

        with open(pred_file, "w", encoding="utf-8") as p, open(
            target_file, "w", encoding="utf-8"
        ) as t:
            # One index line followed by one text line per example.
            for i, (prediction, target) in enumerate(zip(predictions, targets)):
                p.write(f"{i}\n{prediction}\n")
                t.write(f"{i}\n{target}\n")
|
53 |
+
|
54 |
+
|
55 |
+
# IMPORTANT: this should correspond to the chars that were ignored during training.
# Raw string so that the backslashes reach the regex engine unchanged instead
# of being (invalid) Python string escapes.
_CHARS_TO_IGNORE_RE = re.compile(
    r"""[\,\?\.\!\-\;\:\"\“\%\‘\”\�\—\’\…\–\'\।\॔]"""
)

# Order matters: longer sequences must be replaced before their prefixes.
# NOTE(review): the two whitespace entries were indistinguishable single
# spaces in the reviewed copy (a no-op); they are assumed to be the three- and
# two-space sequences -- confirm against the original file.
_TOKEN_SEQUENCES_TO_IGNORE = ("\n\n", "\n", "   ", "  ")


def normalize_text(text: str) -> str:
    """DO ADAPT FOR YOUR USE CASE. Normalize the target text.

    The text is NFKC-normalized, lower-cased and stripped of the ignored
    punctuation characters; afterwards every newline or multi-space sequence
    listed in ``_TOKEN_SEQUENCES_TO_IGNORE`` is replaced by a single space.

    Args:
        text: Raw reference transcript.

    Returns:
        The normalized transcript.
    """
    text = unicodedata.normalize("NFKC", text)
    text = _CHARS_TO_IGNORE_RE.sub("", text.lower())

    for sequence in _TOKEN_SEQUENCES_TO_IGNORE:
        text = text.replace(sequence, " ")

    return text
|
70 |
+
|
71 |
+
|
72 |
+
def main(args):
    """Transcribe one dataset split with the ASR pipeline and log WER/CER.

    Args:
        args: Parsed command-line arguments. Reads ``dataset``, ``config``,
            ``split``, ``model_id``, ``device``, ``chunk_length_s`` and
            ``stride_length_s``; ``device`` is filled in when it is None.
    """
    eval_set = load_dataset(
        args.dataset, args.config, split=args.split, use_auth_token=True
    )

    # For a quick smoke test, restrict the run to the first ten examples:
    # eval_set = eval_set.select(range(10))

    # Resample the audio column to the rate the feature extractor reports.
    target_rate = AutoFeatureExtractor.from_pretrained(args.model_id).sampling_rate
    eval_set = eval_set.cast_column("audio", Audio(sampling_rate=target_rate))

    # No explicit device: use the first GPU when CUDA is available, else CPU.
    if args.device is None:
        args.device = 0 if torch.cuda.is_available() else -1

    asr = pipeline(
        "automatic-speech-recognition", model=args.model_id, device=args.device
    )

    def transcribe(example):
        """Attach the predicted text and the normalized reference text."""
        output = asr(
            example["audio"]["array"],
            chunk_length_s=args.chunk_length_s,
            stride_length_s=args.stride_length_s,
        )
        example["prediction"] = output["text"]
        example["target"] = normalize_text(example["sentence"])
        return example

    # Run inference on every example, dropping all of the original columns.
    result = eval_set.map(transcribe, remove_columns=eval_set.column_names)

    # Compute and log the metrics.
    log_results(result, args)
|
113 |
+
|
114 |
+
|
115 |
+
def build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the evaluation script.

    Returns:
        A parser with the required ``--model_id``, ``--dataset``, ``--config``
        and ``--split`` options and the optional ``--chunk_length_s``,
        ``--stride_length_s``, ``--log_outputs`` and ``--device`` options.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_id",
        type=str,
        required=True,
        help="Model identifier. Should be loadable with 🤗 Transformers",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="Dataset name to evaluate the `model_id`. Should be loadable with 🤗 Datasets",
    )
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Config of the dataset. *E.g.* `'en'` for Common Voice",
    )
    parser.add_argument(
        "--split", type=str, required=True, help="Split of the dataset. *E.g.* `'test'`"
    )
    # The help texts below describe the real defaults (None); the previous
    # wording advertised numeric defaults that were never applied.
    parser.add_argument(
        "--chunk_length_s",
        type=float,
        default=None,
        help="Chunk length in seconds passed to the pipeline. Not set by default.",
    )
    parser.add_argument(
        "--stride_length_s",
        type=float,
        default=None,
        help="Stride of the audio chunks in seconds passed to the pipeline. Not set by default.",
    )
    parser.add_argument(
        "--log_outputs",
        action="store_true",
        help="If defined, write outputs to log file for analysis.",
    )
    parser.add_argument(
        "--device",
        type=int,
        default=None,
        help=(
            "The device to run the pipeline on. -1 for CPU, 0 for the first GPU and so on. "
            "Defaults to the first GPU if CUDA is available, otherwise CPU."
        ),
    )
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()

    main(args)
|
language_model/4gram.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4638564096b5800133f9de231905227633addc987b5a5e5ee7c2e3e64892e802
|
3 |
+
size 5301634441
|
language_model/attrs.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"alpha": 0.5, "beta": 1.5, "unk_score_offset": -10.0, "score_boundary": true}
|
language_model/unigrams.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:475211592ab58f3029e5120c4fb87960d5e909472c2e2fc948393275a4b3d44a
|
3 |
+
size 116468766
|
preprocessor_config.json
CHANGED
@@ -4,6 +4,7 @@
|
|
4 |
"feature_size": 1,
|
5 |
"padding_side": "right",
|
6 |
"padding_value": 0,
|
|
|
7 |
"return_attention_mask": true,
|
8 |
"sampling_rate": 16000
|
9 |
}
|
|
|
4 |
"feature_size": 1,
|
5 |
"padding_side": "right",
|
6 |
"padding_value": 0,
|
7 |
+
"processor_class": "Wav2Vec2ProcessorWithLM",
|
8 |
"return_attention_mask": true,
|
9 |
"sampling_rate": 16000
|
10 |
}
|
special_tokens_map.json
CHANGED
@@ -1 +1 @@
|
|
1 |
-
{"bos_token": "<s>", "eos_token": "</s>", "unk_token": "[UNK]", "pad_token": "[PAD]", "additional_special_tokens": [{"content": "<s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, {"content": "</s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}]}
|
|
|
1 |
+
{"bos_token": "<s>", "eos_token": "</s>", "unk_token": "[UNK]", "pad_token": "[PAD]", "additional_special_tokens": [{"content": "<s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, {"content": "</s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, {"content": "<s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, {"content": "</s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}]}
|
tokenizer_config.json
CHANGED
@@ -1 +1 @@
|
|
1 |
-
{"unk_token": "[UNK]", "bos_token": "<s>", "eos_token": "</s>", "pad_token": "[PAD]", "do_lower_case": false, "word_delimiter_token": "|", "special_tokens_map_file": null, "name_or_path": "
|
|
|
1 |
+
{"unk_token": "[UNK]", "bos_token": "<s>", "eos_token": "</s>", "pad_token": "[PAD]", "do_lower_case": false, "word_delimiter_token": "|", "special_tokens_map_file": null, "name_or_path": "anuragshas/wav2vec2-xls-r-300m-mr-cv8-with-lm", "tokenizer_class": "Wav2Vec2CTCTokenizer", "processor_class": "Wav2Vec2ProcessorWithLM"}
|