YSU
/

lilitket commited on
Commit
5f1c16f
·
1 Parent(s): aa3e1cb

Source Files

Browse files
Files changed (6) hide show
  1. cleaning.py +15 -0
  2. collator.py +90 -0
  3. compute_wer.py +179 -0
  4. fine_tune.py +357 -0
  5. lm_fusion.py +56 -0
  6. utils.py +62 -0
cleaning.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import fire
3
+
4
+ from aspram.utils import clean_characters
5
+
6
+
7
+ def exec(lower: bool = False, only_mesropatar: bool = False):
8
+ for line in sys.stdin:
9
+ line = line.strip()
10
+ line = clean_characters(dict(sentence=line), lower=lower, only_mesropatar=only_mesropatar)['sentence']
11
+ sys.stdout.write(line + "\n")
12
+
13
+
14
+ if __name__ == '__main__':
15
+ fire.Fire(exec)
collator.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Optional, Union
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+ from transformers import Wav2Vec2Processor
7
+
8
+ from torch_audiomentations import Compose, Gain
9
+ from audiomentations import (
10
+ Compose,
11
+ AddGaussianNoise,
12
+ AddGaussianSNR,
13
+ ClippingDistortion,
14
+ FrequencyMask,
15
+ Gain,
16
+ LoudnessNormalization,
17
+ Normalize,
18
+ PitchShift,
19
+ PolarityInversion,
20
+ Shift,
21
+ TimeMask,
22
+ TimeStretch,
23
+ )
24
+
25
+
26
+ class DataCollatorCTCWithPadding:
27
+
28
+ def __init__(
29
+ self,
30
+ processor: Wav2Vec2Processor,
31
+ padding: Union[bool, str] = True,
32
+ sample_rate: int = 16_000,
33
+ apply_gaussian_noise_with_p: float = 0,
34
+ apply_gain_with_p: float = 0,
35
+ apply_pitch_shift_with_p: float = 0,
36
+ apply_time_stretch_with_p: float = 0,
37
+ ):
38
+ self.processor = processor
39
+ self.padding = padding
40
+ self.apply_gaussian_noise_with_p = apply_gaussian_noise_with_p
41
+ self.apply_gain_with_p = apply_gain_with_p
42
+ self.apply_pitch_shift_with_p = apply_pitch_shift_with_p
43
+ self.apply_time_stretch_with_p = apply_time_stretch_with_p
44
+ self.sample_rate = sample_rate
45
+
46
+ self.augmentator = None
47
+ if self.apply_gaussian_noise_with_p + self.apply_gain_with_p + self.apply_pitch_shift_with_p + self.apply_time_stretch_with_p > 0:
48
+ self.augmentator = Compose([
49
+ TimeStretch(min_rate=0.8, max_rate=1.2, leave_length_unchanged=False, p=self.apply_time_stretch_with_p),
50
+ PitchShift(min_semitones=-1, max_semitones=1, p=self.apply_pitch_shift_with_p),
51
+ Gain(min_gain_in_db=-1, max_gain_in_db=1, p=self.apply_gain_with_p),
52
+ AddGaussianNoise(min_amplitude=0.0001, max_amplitude=0.001, p=self.apply_gaussian_noise_with_p),
53
+ ])
54
+
55
+ def _apply_augmentation(self, input_values: List[float]):
56
+ """apply some audio augmentations in the given input_values"""
57
+ if self.augmentator is not None:
58
+ return self.augmentator(samples=np.array(input_values), sample_rate=self.sample_rate).tolist()
59
+ else:
60
+ return input_values
61
+
62
+ def __call__(
63
+ self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
64
+ ) -> Dict[str, torch.Tensor]:
65
+ # TODO maybe disable augmentation in inference mode?
66
+ input_features = [
67
+ {"input_values": self._apply_augmentation(feature["input_values"])} for feature in features
68
+ ]
69
+ label_features = [{"input_ids": feature["labels"]} for feature in features]
70
+
71
+ batch = self.processor.pad(
72
+ input_features,
73
+ padding=self.padding,
74
+ return_tensors="pt",
75
+ )
76
+ with self.processor.as_target_processor():
77
+ labels_batch = self.processor.pad(
78
+ label_features,
79
+ padding=self.padding,
80
+ return_tensors="pt",
81
+ )
82
+
83
+ # replace padding with -100 to ignore loss correctly
84
+ labels = labels_batch["input_ids"].masked_fill(
85
+ labels_batch.attention_mask.ne(1), -100
86
+ )
87
+
88
+ batch["labels"] = labels
89
+
90
+ return batch
compute_wer.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import weakref
2
+ import torch
3
+ import numpy as np
4
+ from tqdm import tqdm
5
+
6
+ from transformers import Wav2Vec2ProcessorWithLM, Wav2Vec2Processor
7
+ from transformers import AutoModelForCTC, Wav2Vec2Processor, Wav2Vec2ForCTC
8
+
9
+ from datasets import load_dataset, load_metric, Audio
10
+
11
+ import fire
12
+
13
+ from aspram.utils import clean_characters, prepare_dataset
14
+
15
+ # import sentencepiece as spm
16
+
17
+ # repo_name = "20220414-210228_lm"
18
+ # repo_name = "./20220414-210228_lm_spm_bpe"
19
+ def exec(
20
+ *,
21
+ repo_name: str,
22
+ dataset: str = "yerevann/common_voice_9_0",
23
+ cuda: bool = True,
24
+ batch_size: int = 8,
25
+ beam_width: int = 1,
26
+ j: int = 1,
27
+ sample_rate: int = 16_000,
28
+ alpha: float = None,
29
+ beta: float = None,
30
+ unk_score_offset: float = None,
31
+ lm_score_boundary: bool = None,
32
+ beam_prune_logp: float = None,
33
+ token_min_logp: float = None,
34
+ output_file : str = None,
35
+ ):
36
+
37
+ # repo_name = "20220428-094209--72000_lm"
38
+
39
+ print(f'loading model {repo_name}')
40
+ model = Wav2Vec2ForCTC.from_pretrained(repo_name)
41
+ print('done')
42
+ if cuda:
43
+ print('CUDA mode')
44
+ model.cuda()
45
+
46
+ if repo_name.endswith('_lm'):
47
+ processor = Wav2Vec2ProcessorWithLM.from_pretrained(repo_name, sample_rate=sample_rate)
48
+ with_lm = True
49
+ else:
50
+ processor = Wav2Vec2Processor.from_pretrained(repo_name, sample_rate=sample_rate)
51
+ with_lm = False
52
+
53
+ common_voice_test = load_dataset(
54
+ dataset,
55
+ "hy-AM",
56
+ split="test",
57
+ use_auth_token=True,
58
+ )
59
+ common_voice_test = common_voice_test.map(clean_characters)
60
+ common_voice_test = common_voice_test.cast_column(
61
+ "audio", Audio(sampling_rate=sample_rate)
62
+ )
63
+ common_voice_test = common_voice_test.map(
64
+ prepare_dataset,
65
+ remove_columns=common_voice_test.column_names,
66
+ fn_kwargs=dict(processor=processor)
67
+ )
68
+
69
+
70
+ # wer_metric = load()...
71
+ # for batch in batched_dataset:
72
+ # input_dict = processer(batch)
73
+ # logits = model(input...)
74
+ # wer_metric.update(true, pred)
75
+ # wer_metric.compute
76
+
77
+ # def exec_cer_wer(batch_size: int = 8, **kwargs):
78
+ def predict(batch):
79
+ # print(1)
80
+ input_dict = processor(
81
+ batch["input_values"],
82
+ return_tensors="pt",
83
+ padding=True,
84
+ sampling_rate=sample_rate
85
+ )
86
+ # print(2)
87
+ with torch.no_grad():
88
+ x = input_dict.input_values
89
+ if cuda:
90
+ x = x.cuda()
91
+ logits = model(x).logits
92
+ # print(3)
93
+ if with_lm:
94
+ # print(beam_size)
95
+ # sp = spm.SentencePieceProcessor()
96
+ # sp.load('head_mes_lower_bpe.model')
97
+
98
+ pred = processor.batch_decode(
99
+ logits.cpu().numpy(),
100
+ beam_width=beam_width,
101
+ alpha=alpha,
102
+ beta=beta,
103
+ unk_score_offset=unk_score_offset,
104
+ lm_score_boundary=lm_score_boundary,
105
+ num_processes=j,
106
+ beam_prune_logp=beam_prune_logp, #-1000,
107
+ token_min_logp=token_min_logp,
108
+ # sp=sp,
109
+ ).text
110
+ else:
111
+ pred = processor.batch_decode(
112
+ logits.cpu().numpy().argmax(-1),
113
+ )
114
+ # print(pred)
115
+ # print(pred)
116
+
117
+ return {
118
+ 'sentence': pred
119
+ }
120
+
121
+ with_predictions = common_voice_test.map(predict, batched=True, batch_size=batch_size)
122
+
123
+ def detokenize(sample):
124
+ if '▁' in sample['sentence']:
125
+ print("------ ", sample)
126
+ sample['sentence'] = sample['sentence'].replace(' ', '').replace('▁', ' ')
127
+ print("------ ", sample)
128
+ return sample
129
+
130
+ with_predictions = with_predictions.map(detokenize)
131
+
132
+ common_voice_test_transcription = load_dataset(
133
+ dataset,
134
+ "hy-AM",
135
+ split="test",
136
+ use_auth_token=True,
137
+ )
138
+
139
+ with_predictions = with_predictions.map(clean_characters, fn_kwargs=dict(lower=True, only_mesropatar=True))
140
+ common_voice_test_transcription = common_voice_test_transcription.map(clean_characters, fn_kwargs=dict(lower=True, only_mesropatar=True))
141
+
142
+ predictions = with_predictions['sentence']
143
+ references = common_voice_test_transcription['sentence']
144
+
145
+ wer_metric = load_metric("wer")
146
+ cer_metric = load_metric("cer")
147
+
148
+ for ref, pred in zip(references, predictions):
149
+ print(f' REF:\t{ref}')
150
+ print(f'PRED:\t{pred}')
151
+ print('\n')
152
+
153
+ wer = wer_metric.compute(predictions=predictions, references=references)
154
+ cer = cer_metric.compute(predictions=predictions, references=references)
155
+ print("wer: ", wer)
156
+ print("cer: ", cer)
157
+
158
+ df = common_voice_test_transcription.to_pandas()['sentence']
159
+ df = df.to_frame()
160
+ df["predictions"] = with_predictions.to_pandas()['sentence']
161
+
162
+ # df.insert(2, "predictions", with_predictions['sentence'], True)
163
+
164
+ if output_file is not None:
165
+ df.to_csv(output_file)
166
+
167
+ # exec_cer_wer(beam_width=beam_width, batch_size=batch_size)
168
+
169
+ # for pruning_score in {-10, -100, -2000}:
170
+ # for alpha in {1, 0.5, 1.5}:
171
+ # for beta in {1, 0.5, 1.5}:
172
+ # for beam_size in {0, 2, 4, 6}:
173
+ # print("Configuration:")
174
+ # print("alpha {alpha} beta {beta}, beam_width {beam_size}, pruning_score {pruning_score}".format(alpha = alpha, beta = beta, beam_size = beam_size, pruning_score = pruning_score))
175
+ # exec_cer_wer(alpha, beta, 2**beam_size, pruning_score, batch_size=batch_size)
176
+ # print('\n\n')
177
+
178
+ if __name__ == "__main__":
179
+ fire.Fire(exec)
fine_tune.py ADDED
@@ -0,0 +1,357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Optional, Union
2
+
3
+ import os
4
+ import json
5
+ import time
6
+
7
+ import numpy as np
8
+
9
+ from transformers import Trainer
10
+ from transformers import Wav2Vec2ForCTC
11
+ from transformers import TrainingArguments
12
+ from transformers import Wav2Vec2Processor
13
+ from transformers import Wav2Vec2CTCTokenizer
14
+ from transformers import Wav2Vec2FeatureExtractor
15
+
16
+ from datasets import load_dataset, load_metric, Audio, concatenate_datasets, load_from_disk
17
+
18
+ from aim import Run
19
+ from aim.hugging_face import AimCallback
20
+
21
+ import fire
22
+
23
+ from aspram.collator import DataCollatorCTCWithPadding
24
+ from aspram.utils import clean_characters, extract_all_chars, prepare_dataset
25
+
26
+
27
+ def load_data(dataset_name: str, *, split: str):
28
+ dataset_name = dataset_name.replace(' ', '')
29
+
30
+ if '+' in dataset_name:
31
+ return concatenate_datasets([
32
+ load_data(name, split=split)
33
+ for name in dataset_name.split('+')
34
+ ])
35
+
36
+ if '*' in dataset_name:
37
+ a, _, b = dataset_name.partition('*')
38
+ if a.isnumeric():
39
+ num_repeats = int(a)
40
+ dataset_name = b
41
+ else:
42
+ num_repeats = int(b)
43
+ dataset_name = a
44
+
45
+ dataset = load_data(dataset_name, split=split)
46
+
47
+ return concatenate_datasets([
48
+ dataset
49
+ for _ in range(num_repeats)
50
+ ])
51
+
52
+ if 'teacher' in dataset_name:
53
+ dataset = load_from_disk(
54
+ dataset_name,
55
+ ).filter(
56
+ lambda sample: len(sample['audio']['array']) < 250_000
57
+ )
58
+ elif 'common_voice' in dataset_name:
59
+ dataset = load_dataset(
60
+ dataset_name,
61
+ "hy-AM",
62
+ split="train+validation+other" if split == 'train' else split,
63
+ use_auth_token=True,
64
+ )
65
+ else:
66
+ dataset = load_dataset(
67
+ dataset_name,
68
+ 'hy_am',
69
+ split='train',
70
+ ).map(
71
+ lambda sample: dict(sentence=sample['transcription'])
72
+ ).filter(
73
+ lambda sample: sample['num_samples'] < 250_000
74
+ )
75
+
76
+ non_wanted_column_name = set(dataset.column_names) - set(['audio', 'path', 'sentence', 'client_id'])
77
+
78
+ dataset = dataset.map(remove_columns=non_wanted_column_name).cast_column("audio", Audio(sampling_rate=16_000))
79
+
80
+ return dataset
81
+
82
+
83
+ def exec(
84
+ *,
85
+ batch_size: int,
86
+ lr: float,
87
+ warmup_steps: int = 2000,
88
+ grad_acc: int = 1,
89
+ group_by_length: bool = True,
90
+ fp16: bool = True,
91
+ bf16: bool = False,
92
+ pretrained_model: str = "facebook/wav2vec2-xls-r-2b",
93
+ dataset: str = "mozilla-foundation/common_voice_8_0",
94
+ num_train_epochs: int = 1200,
95
+ blacklist_enabled: bool = True,
96
+ seed: int = 42,
97
+ # random augment
98
+ apply_gaussian_noise_with_p: float = 0,
99
+ apply_gain_with_p: float = 0,
100
+ apply_pitch_shift_with_p: float = 0,
101
+ apply_time_stretch_with_p: float = 0,
102
+ # spec augment
103
+ mask_time_prob: float = 0.05, # value that is used in the previous models
104
+ mask_time_length: int = 10,
105
+ mask_time_min_masks: int = 2,
106
+ mask_feature_prob: float = 0,
107
+ mask_feature_length: int = 10,
108
+ mask_feature_min_masks: int = 0,
109
+
110
+ layerdrop: float = 0,
111
+ activation_dropout: float = 0.1,
112
+
113
+ lower: bool = False,
114
+ only_mesropatar: bool = False,
115
+ gradient_checkpointing: bool = False,
116
+ resume_from_hash: str = None,
117
+ ):
118
+ if bf16:
119
+ fp16 = False
120
+ fire_args = locals()
121
+
122
+ run = Run(resume_from_hash, log_system_params=(not resume_from_hash))
123
+ if not resume_from_hash:
124
+ timestr = time.strftime("%Y%m%d-%H%M%S")
125
+ repo_name = os.path.join('models', timestr)
126
+ for key, value in fire_args.items():
127
+ run['hparams', key] = value
128
+ run['fire', key] = value
129
+ else:
130
+ repo_name = run['hparams', 'output_dir']
131
+ run_hash = run.hash
132
+ run = None
133
+
134
+
135
+ train_dataset = load_data(dataset, split="train")
136
+
137
+ blacklist_client_ids = set()
138
+ blacklist_sentences = set()
139
+
140
+ if blacklist_enabled:
141
+ blacklist_client_ids = {
142
+ "93fa435db2b9e077af647c9f846d8b6031bcb1f6cd731e894a835e70a0ab4aec1faffce01c882bdcdcb854b98b601c83a1c412bae8e5ee411556f0e2f88c1c5c",
143
+ "f0aba38a8ab8705a40d05d96829ded5738a7eec7a9a182394c2ed288fc1c64553abcb1e0c4c966ffab9e8b76c27616b9f0503f92c42fe11249af36c50d3de5ef",
144
+ "a528aa436a34dce3b4ddc198c105ebb904967acdd04157bd1b0e0b2ffadd99b36a6cc5fe76f23c3dd2263d1507bec6038c41cb521ac8ee34126133e559df9e75",
145
+ "b83375c41b8ef9ab1b64491b624302b1541b0ba8496ed4e5cb4a751766d7a2cf7430e49e7118eaac98f5ae478d8cdd2b59d18526632297185bbc2e10e2126b18",
146
+ "330411ed21c5d9cda96180ac633b4dd10f5b6e50968e83a64f0016c9e15f22445fa8f396ef92b70ff03fc78e36b35b1693af60431b61b50b706aa58a00f80641",
147
+ }
148
+
149
+ # valid_dataset = load_data(dataset, split="test")
150
+ valid_dataset = load_data("yerevann/common_voice_9_0", split="test")
151
+
152
+ # train_client_ids = set(train_dataset['client_id']) - { None }
153
+ valid_client_ids = set(valid_dataset['client_id']) - { None }
154
+ blacklist_sentences = set(valid_dataset['sentence'])
155
+ blacklist_client_ids |= valid_client_ids
156
+
157
+ train_dataset = train_dataset.filter(
158
+ lambda sample: (
159
+ sample.get("client_id") not in blacklist_client_ids
160
+ and
161
+ sample.get("sentence") not in blacklist_sentences
162
+ )
163
+ )
164
+
165
+ # print('\n' * 10 + '================================' + '\n' * 10)
166
+ # print(train_client_ids & valid_client_ids)
167
+ # print('\n' * 10 + '================================' + '\n' * 10)
168
+
169
+ # train_dataset = train_dataset.remove_columns(
170
+ # [
171
+ # "accent",
172
+ # "age",
173
+ # "client_id",
174
+ # "down_votes",
175
+ # "gender",
176
+ # "locale",
177
+ # "segment",
178
+ # "up_votes",
179
+ # ]
180
+ # )
181
+ # valid_dataset = valid_dataset.remove_columns(
182
+ # [
183
+ # "accent",
184
+ # "age",
185
+ # "client_id",
186
+ # "down_votes",
187
+ # "gender",
188
+ # "locale",
189
+ # "segment",
190
+ # "up_votes",
191
+ # ]
192
+ # )
193
+
194
+ train_dataset = train_dataset.map(clean_characters, fn_kwargs=dict(lower=lower, only_mesropatar=only_mesropatar))
195
+ valid_dataset = valid_dataset.map(clean_characters, fn_kwargs=dict(lower=lower, only_mesropatar=only_mesropatar))
196
+
197
+ if 'models/' in pretrained_model:
198
+ tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(pretrained_model)
199
+ elif not resume_from_hash:
200
+ vocab_train = train_dataset.map(
201
+ extract_all_chars,
202
+ batched=True,
203
+ batch_size=-1,
204
+ keep_in_memory=True,
205
+ remove_columns=train_dataset.column_names,
206
+ )
207
+ vocab_valid = valid_dataset.map(
208
+ extract_all_chars,
209
+ batched=True,
210
+ batch_size=-1,
211
+ keep_in_memory=True,
212
+ remove_columns=valid_dataset.column_names,
213
+ )
214
+ vocab_list = list(set(vocab_train["vocab"][0]) | set(vocab_valid["vocab"][0]))
215
+ vocab_dict = {v: k for k, v in enumerate(sorted(vocab_list))}
216
+ vocab_dict["|"] = vocab_dict[" "]
217
+ del vocab_dict[" "]
218
+
219
+ vocab_dict["[UNK]"] = len(vocab_dict)
220
+ vocab_dict["[PAD]"] = len(vocab_dict)
221
+
222
+ with open("vocab.json", "w") as vocab_file:
223
+ json.dump(vocab_dict, vocab_file)
224
+
225
+ tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(
226
+ "./", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|"
227
+ )
228
+ tokenizer.push_to_hub(repo_name) # smth is wrong here
229
+ else:
230
+ tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(repo_name)
231
+
232
+ feature_extractor = Wav2Vec2FeatureExtractor(
233
+ feature_size=1,
234
+ sampling_rate=16000,
235
+ padding_value=0.0,
236
+ do_normalize=True,
237
+ return_attention_mask=True,
238
+ )
239
+ processor = Wav2Vec2Processor(
240
+ feature_extractor=feature_extractor,
241
+ tokenizer=tokenizer,
242
+ )
243
+
244
+
245
+ train_dataset = train_dataset.cast_column(
246
+ "audio", Audio(sampling_rate=16_000)
247
+ )
248
+ valid_dataset = valid_dataset.cast_column(
249
+ "audio", Audio(sampling_rate=16_000)
250
+ )
251
+
252
+ train_dataset = train_dataset.map(
253
+ prepare_dataset, remove_columns=train_dataset.column_names,
254
+ fn_kwargs=dict(processor=processor)
255
+ )
256
+ valid_dataset = valid_dataset.map(
257
+ prepare_dataset, remove_columns=valid_dataset.column_names,
258
+ fn_kwargs=dict(processor=processor)
259
+ )
260
+
261
+ data_collator = DataCollatorCTCWithPadding(
262
+ processor=processor,
263
+ padding=True,
264
+ sample_rate=16_000,
265
+ apply_gaussian_noise_with_p=apply_gaussian_noise_with_p,
266
+ apply_gain_with_p=apply_gain_with_p,
267
+ apply_pitch_shift_with_p=apply_pitch_shift_with_p,
268
+ apply_time_stretch_with_p=apply_time_stretch_with_p,
269
+ )
270
+
271
+ def compute_metrics(pred):
272
+ pred_logits = pred.predictions
273
+ pred_ids = np.argmax(pred_logits, axis=-1)
274
+
275
+ pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id
276
+
277
+ pred_str = processor.batch_decode(pred_ids)
278
+ # we do not want to group tokens when computing the metrics
279
+ label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
280
+
281
+ wer = wer_metric.compute(predictions=pred_str, references=label_str)
282
+ cer = cer_metric.compute(predictions=pred_str, references=label_str)
283
+
284
+ return {"wer": wer, "cer": cer}
285
+
286
+ wer_metric = load_metric("wer")
287
+ cer_metric = load_metric("cer")
288
+
289
+ def model_init():
290
+ from transformers import Wav2Vec2Config
291
+ model = Wav2Vec2ForCTC.from_pretrained(
292
+ pretrained_model,
293
+ attention_dropout=0.0,
294
+ hidden_dropout=0.0,
295
+ feat_proj_dropout=0.0,
296
+ mask_time_prob=mask_time_prob,
297
+ mask_time_length=mask_time_length,
298
+ mask_time_min_masks=mask_time_min_masks,
299
+ mask_feature_prob=mask_feature_prob,
300
+ mask_feature_length=mask_feature_length,
301
+ mask_feature_min_masks=mask_feature_min_masks,
302
+ layerdrop=layerdrop,
303
+ activation_dropout=activation_dropout,
304
+ ctc_loss_reduction="mean",
305
+ pad_token_id=processor.tokenizer.pad_token_id,
306
+ vocab_size=len(processor.tokenizer),
307
+ )
308
+ model.freeze_feature_extractor()
309
+ return model
310
+
311
+ training_args = TrainingArguments(
312
+ output_dir=repo_name,
313
+ group_by_length=group_by_length,
314
+ per_device_train_batch_size=batch_size,
315
+ gradient_accumulation_steps=grad_acc,
316
+ evaluation_strategy="steps",
317
+ num_train_epochs=num_train_epochs,
318
+ gradient_checkpointing=gradient_checkpointing if resume_from_hash is None else True,
319
+ fp16=fp16,
320
+ bf16=bf16,
321
+ save_steps=4000,
322
+ eval_steps=200,
323
+ logging_steps=200,
324
+ learning_rate=lr, # TODO
325
+ warmup_steps=warmup_steps,
326
+ save_total_limit=1,
327
+ push_to_hub=True,
328
+ metric_for_best_model="eval_wer",
329
+ greater_is_better=False,
330
+ seed=seed,
331
+ )
332
+
333
+ aim_callback = AimCallback()
334
+ aim_callback._run_hash = run_hash
335
+
336
+
337
+ print(train_dataset)
338
+ # run = aim_callback.experiment
339
+
340
+ trainer = Trainer(
341
+ model_init=model_init,
342
+ data_collator=data_collator,
343
+ args=training_args,
344
+ compute_metrics=compute_metrics,
345
+ train_dataset=train_dataset,
346
+ eval_dataset=valid_dataset,
347
+ tokenizer=processor.feature_extractor,
348
+ callbacks=[aim_callback],
349
+ )
350
+
351
+ trainer.train(resume_from_checkpoint=bool(resume_from_hash))
352
+
353
+ trainer.push_to_hub()
354
+
355
+
356
+ if __name__ == "__main__":
357
+ fire.Fire(exec)
lm_fusion.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoProcessor
2
+ from transformers import Wav2Vec2ProcessorWithLM
3
+
4
+ from pyctcdecode import build_ctcdecoder
5
+
6
+ from huggingface_hub import Repository
7
+
8
+ import logging
9
+
10
+ import fire
11
+
12
+
13
+ logging.basicConfig(level=logging.INFO)
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ def exec(
18
+ kenlm_model_path: str,
19
+ model_name: str,
20
+ lm_model_name: str = "",
21
+ ):
22
+ if not lm_model_name:
23
+ lm_model_name = model_name + "_lm"
24
+ logger.info(f'writing on {lm_model_name}')
25
+ logger.info(f'loading processor of `{model_name}`')
26
+ processor = AutoProcessor.from_pretrained(model_name)
27
+ logger.info(f'done loading `{model_name}`')
28
+
29
+ vocab_dict = processor.tokenizer.get_vocab()
30
+ sorted_vocab_dict = {
31
+ k: v for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])
32
+ }
33
+
34
+ logger.info(f'building ctc decoder from {kenlm_model_path}')
35
+ decoder = build_ctcdecoder(
36
+ labels=list(sorted_vocab_dict.keys()),
37
+ kenlm_model_path=kenlm_model_path,
38
+ )
39
+ logger.info('done')
40
+
41
+ processor_with_lm = Wav2Vec2ProcessorWithLM(
42
+ feature_extractor=processor.feature_extractor,
43
+ tokenizer=processor.tokenizer,
44
+ decoder=decoder,
45
+ )
46
+
47
+ # repo = Repository(
48
+ # local_dir=lm_model_name, clone_from=model_name
49
+ # ) # model_name
50
+ # repo.push_to_hub()
51
+
52
+ processor_with_lm.save_pretrained(lm_model_name)
53
+
54
+
55
+ if __name__ == "__main__":
56
+ fire.Fire(exec)
utils.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ def clean_characters(sample, lower: bool = False, only_mesropatar: bool = False):
4
+
5
+ if 'sentence' not in sample:
6
+ if 'transcription' not in sample:
7
+ raise NotImplementedError()
8
+ else:
9
+ sample['sentence'] = sample['transcription']
10
+
11
+ allowed_chars = (
12
+ "-"
13
+ "a-z"
14
+ "A-Z"
15
+ "0-9"
16
+ "ԱԲԳԴԵԶԷԸԹԺԻԼԽԾԿՀՁՂՃՄՅՆՇՈՉՊՋՌՍՎՏՐՑՒՓՔՕՖ"
17
+ "աբգդեզէըթժիլխծկհձղճմյնշոչպջռսվտրցւփքօֆև"
18
+ " \"'։֊.:?;,ՙ՚՛՜՝՞՟\(\)"
19
+ )
20
+ if lower:
21
+ sample["sentence"] = sample["sentence"].lower()
22
+
23
+ if only_mesropatar:
24
+ allowed_chars = (
25
+ "ԱԲԳԴԵԶԷԸԹԺԻԼԽԾԿՀՁՂՃՄՅՆՇՈՉՊՋՌՍՎՏՐՑՒՓՔՕՖ"
26
+ "աբգդեզէըթժիլխծկհձղճմյնշոչպջռսվտրցւփքօֆև"
27
+ " -"
28
+ )
29
+ sample["sentence"] = re.sub(f"[^{allowed_chars}]", "", sample["sentence"])
30
+ # print(sample["sentence"])
31
+ return sample
32
+
33
+ def extract_all_chars(batch):
34
+ all_text = " ".join(batch["sentence"])
35
+ vocab = list(set(all_text))
36
+ return {"vocab": [vocab], "all_text": [all_text]}
37
+
38
+ def prepare_dataset(smaple, processor):
39
+ audio = smaple["audio"]
40
+
41
+ smaple["input_values"] = processor(
42
+ audio["array"], sampling_rate=audio["sampling_rate"]
43
+ ).input_values[0]
44
+ smaple["input_length"] = len(smaple["input_values"])
45
+
46
+ with processor.as_target_processor():
47
+ smaple["labels"] = processor(smaple["sentence"]).input_ids
48
+ return smaple
49
+
50
+
51
+ def batched_prepare_dataset(batch, processor):
52
+ batch = batch.copy()
53
+ audio = batch["audio"]
54
+
55
+ batch["input_values"] = processor(
56
+ [i["array"] for i in audio], sampling_rate=16_000
57
+ ).input_values
58
+ batch["input_length"] = [len(i) for i in batch["input_values"] ]
59
+
60
+ with processor.as_target_processor():
61
+ batch["labels"] = processor(batch["sentence"]).input_ids
62
+ return batch