from transformers import (Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2CTCTokenizer,
                          Wav2Vec2FeatureExtractor, TrainerCallback)
from datasets import load_from_disk
from data_handler import DataCollatorCTCWithPadding
from transformers import TrainingArguments
from transformers import Trainer, logging
from metric_utils import compute_metrics_fn
from transformers.trainer_utils import get_last_checkpoint
import json
import os, glob
from callbacks import BreakEachEpoch
import subprocess
from multiprocessing import Process
import shutil

logging.set_verbosity_info()
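
# This script fine-tunes a Wav2Vec2 CTC model on a sharded speech corpus. Each Trainer
# "epoch" consumes a single training shard; while one shard trains, the next shard is
# preprocessed in a background process, and checkpoints are periodically committed to git.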


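# Build the (model, processor) pair either from the local pre-trained base checkpoint with a
# freshly built CTC tokenizer, or from an existing fine-tuning checkpoint directory.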
def load_pretrained_model(checkpoint_path=None):
    if checkpoint_path is None:
        pre_trained_path = './model-bin/pretrained/base'
        tokenizer = Wav2Vec2CTCTokenizer("./model-bin/finetune/vocab.json",
                                         unk_token="<unk>",
                                         pad_token="<pad>",
                                         word_delimiter_token="|")

        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(pre_trained_path)
        processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

        model = Wav2Vec2ForCTC.from_pretrained(
            pre_trained_path,
            gradient_checkpointing=True,
            ctc_loss_reduction="mean",
            pad_token_id=processor.tokenizer.pad_token_id,
        )
        model.freeze_feature_extractor()
    else:
        tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(checkpoint_path)

        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(checkpoint_path)
        processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

        model = Wav2Vec2ForCTC.from_pretrained(
            checkpoint_path,
            gradient_checkpointing=True,
            ctc_loss_reduction="mean",
            pad_token_id=processor.tokenizer.pad_token_id,
        )
        # model.freeze_feature_extractor()

    # model = Wav2Vec2ForCTC(model.config)
    model_total_params = sum(p.numel() for p in model.parameters())
    model_total_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(model)
    print("model_total_params: {}\nmodel_total_params_trainable: {}".format(model_total_params,
                                                                            model_total_params_trainable))
    return model, processor


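# datasets .map() callback: turns raw speech arrays into model input values and encodes
# the target transcripts into CTC label ids.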
def prepare_dataset(batch, processor):
    # check that all files share the sampling rate expected by the feature extractor
    assert (
        len(set(batch["sampling_rate"])) == 1
        and batch["sampling_rate"][0] == processor.feature_extractor.sampling_rate
    ), f"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}."

    batch["input_values"] = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0]).input_values

    batch["length"] = [len(item) for item in batch["input_values"]]

    with processor.as_target_processor():
        batch["labels"] = processor(batch["target_text"]).input_ids
    return batch


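# Load a dataset shard from disk, length-filter and encode it with prepare_dataset(), and
# cache the results as Arrow files. Returns None when the shard cannot be processed. When
# pointed at the prefetch cache folder, the call only warms the cache for the next shard.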
def load_prepared_dataset(path, processor, cache_file_filter_name, cache_file_map_name, num_proc=5):
    try:
        dataset = load_from_disk(path)
        list_cache_prefetch_files = glob.glob(
            cache_file_map_name.replace(cache_processing_dataset_folder,
                                        cache_processing_dataset_folder_prefetch).replace('.arrow', '*'))

        # When building the prefetch cache, do not re-compute what is already cached
        if cache_file_map_name.startswith(cache_processing_dataset_folder_prefetch):
            if len(glob.glob(cache_file_map_name.replace(cache_processing_dataset_folder_prefetch,
                                                         cache_processing_dataset_folder).replace('.arrow', '*'))) > 0:
                return
            if len(list_cache_prefetch_files) > 0:
                return

        # Promote prefetched cache files into the main cache folder
        if len(glob.glob(cache_file_map_name.replace('.arrow', '*'))) == 0 and len(list_cache_prefetch_files) > 0:
            for item_file in list_cache_prefetch_files:
                shutil.move(item_file, item_file.replace(cache_processing_dataset_folder_prefetch,
                                                         cache_processing_dataset_folder))

        # If a map cache already exists, reuse it and skip the length filter
        if len(glob.glob(cache_file_map_name.replace('.arrow', '*'))) > 0:
            return dataset.map(prepare_dataset,
                               remove_columns=dataset.column_names,
                               batch_size=32,
                               num_proc=num_proc,
                               batched=True,
                               fn_kwargs={"processor": processor},
                               cache_file_name=cache_file_map_name)

        # Drop utterances longer than 160,000 samples (10 s at the 16 kHz rate Wav2Vec2 expects),
        # then encode the shard and cache the result
        dataset = dataset.filter(lambda example: len(example['speech']) < 160000,
                                 batch_size=32,
                                 num_proc=num_proc,
                                 cache_file_name=cache_file_filter_name)
        processed_dataset = dataset.map(prepare_dataset,
                                        remove_columns=dataset.column_names,
                                        batch_size=32,
                                        num_proc=num_proc,
                                        batched=True,
                                        fn_kwargs={"processor": processor},
                                        cache_file_name=cache_file_map_name)
        processed_dataset.cleanup_cache_files()
        return processed_dataset
    except Exception:
        # A missing or corrupted shard is skipped instead of stopping the training run
        return None


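# Stage, commit and push the latest checkpoint. Assumes the script runs from the root of a
# git repository with push access to a branch named "main" (as used by the commands below).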
def commit_checkpoint():
    submit_commands = [
        ['git', 'add', 'model-bin/finetune/base/*'],
        ['git', 'commit', '-m', 'auto-commit'],
        ['git', 'push', 'origin', 'main'],
    ]
    for command in submit_commands:
        print(subprocess.run(command, stdout=subprocess.PIPE).stdout.decode('utf-8'))


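# Map a (shard-)epoch index to the training shard to train on and the matching test
# shard / sub-shard to evaluate on.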
def get_train_test_shard_id(epoch_count):
    # loop over training shards
    _train_dataset_shard_idx = epoch_count % num_train_shards
    # pick the test shard that corresponds to this training shard
    _test_dataset_shard_idx = min(round(_train_dataset_shard_idx / (num_train_shards / num_test_shards)),
                                  num_test_shards - 1)
    _num_test_sub_shard = 8  # split each test shard into this many sub-shards
    _idx_sub_shard = _train_dataset_shard_idx % _num_test_sub_shard  # cycle through the sub-shards
    return _train_dataset_shard_idx, _test_dataset_shard_idx, _num_test_sub_shard, _idx_sub_shard


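# Warm the prefetch cache for the next epoch: preprocess the upcoming train and test shards
# into cache_processing_dataset_folder_prefetch so the main loop can reuse them.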
def process_prefetch_epoch(epoch_count):
    train_shard_idx, test_shard_idx, _, _ = get_train_test_shard_id(epoch_count)
    load_prepared_dataset(os.path.join(train_dataset_root_folder,
                                       'shard_{}'.format(train_shard_idx)),
                          w2v_ctc_processor,
                          cache_file_filter_name=os.path.join(cache_processing_dataset_folder_prefetch,
                                                              'train',
                                                              'cache-train-filter-shard-{}.arrow'.format(
                                                                  train_shard_idx)),
                          cache_file_map_name=os.path.join(cache_processing_dataset_folder_prefetch,
                                                           'train',
                                                           'cache-train-map-shard-{}.arrow'.format(
                                                               train_shard_idx)),
                          )
    load_prepared_dataset(os.path.join(test_dataset_root_folder,
                                       'shard_{}'.format(test_shard_idx)),
                          w2v_ctc_processor,
                          cache_file_filter_name=os.path.join(cache_processing_dataset_folder_prefetch,
                                                              'test',
                                                              'cache-test-filter-shard-{}.arrow'.format(
                                                                  test_shard_idx)),
                          cache_file_map_name=os.path.join(cache_processing_dataset_folder_prefetch, 'test',
                                                           'cache-test-map-shard-{}.arrow'.format(
                                                               test_shard_idx))
                          )


if __name__ == "__main__":

    checkpoint_path = "./model-bin/finetune/base/"

    # train_dataset_root_folder = './data-bin/train_dataset'
    # test_dataset_root_folder = './data-bin/test_dataset'

    train_dataset_root_folder = '/content/drive/MyDrive/audio_dataset/train_dataset'
    test_dataset_root_folder = '/content/drive/MyDrive/audio_dataset/test_dataset'

    cache_processing_dataset_folder = '/dev/shm/cache/'
    cache_processing_dataset_folder_prefetch = './data-bin/cache_prefetch/'
    # Create the cache folders if missing; exist_ok covers the case where only one of the
    # train/test subfolders is already present
    for cache_root in (cache_processing_dataset_folder, cache_processing_dataset_folder_prefetch):
        os.makedirs(os.path.join(cache_root, 'train'), exist_ok=True)
        os.makedirs(os.path.join(cache_root, 'test'), exist_ok=True)
    num_train_shards = len(glob.glob(os.path.join(train_dataset_root_folder, 'shard_*')))
    num_test_shards = len(glob.glob(os.path.join(test_dataset_root_folder, 'shard_*')))
    num_epochs = 5000

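    # The Trainer is configured for num_epochs "epochs", but each epoch only sees the
    # current shard: the BreakEachEpoch callback interrupts training after every epoch
    # so the loop below can swap in the next shard and resume from the last checkpoint.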
    training_args = TrainingArguments(
        output_dir=checkpoint_path,
        fp16=True,
        group_by_length=True,
        per_device_train_batch_size=32,
        per_device_eval_batch_size=32,
        gradient_accumulation_steps=2,
        num_train_epochs=num_epochs,  # one epoch per data shard
        logging_steps=5,
        learning_rate=1e-5,
        weight_decay=0.005,
        warmup_steps=1000,
        save_total_limit=2,
        ignore_data_skip=True,
        logging_dir=os.path.join(checkpoint_path, 'log'),
        metric_for_best_model='wer',
        save_strategy="epoch",
        evaluation_strategy="epoch",
        greater_is_better=False,
        # save_steps=5,
        # eval_steps=5,
    )
    trainer = None

    last_checkpoint_path = None
    last_epoch_idx = 0
    if os.path.exists(checkpoint_path):
        last_checkpoint_path = get_last_checkpoint(checkpoint_path)
        if last_checkpoint_path is not None:
            with open(os.path.join(last_checkpoint_path, "trainer_state.json"), 'r', encoding='utf-8') as file:
                trainer_state = json.load(file)
                last_epoch_idx = int(trainer_state['epoch'])

    w2v_ctc_model, w2v_ctc_processor = load_pretrained_model()
    data_collator = DataCollatorCTCWithPadding(processor=w2v_ctc_processor, padding=True)

    prefetch_process = []

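    # Main loop: one pass over one training shard per iteration. While the Trainer works
    # on the current shard, a background process prefetches and preprocesses the next one.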
    for epoch_idx in range(last_epoch_idx, num_epochs):
        train_dataset_shard_idx, test_dataset_shard_idx, num_test_sub_shard, idx_sub_shard = get_train_test_shard_id(
            epoch_idx)

        # wait for the prefetch processes from the previous epoch to finish
        for process_instance in prefetch_process:
            process_instance.join()
        prefetch_process.clear()

        # load train shard
        train_dataset = load_prepared_dataset(os.path.join(train_dataset_root_folder,
                                                           'shard_{}'.format(train_dataset_shard_idx)),
                                              w2v_ctc_processor,
                                              cache_file_filter_name=os.path.join(cache_processing_dataset_folder,
                                                                                  'train',
                                                                                  'cache-train-filter-shard-{}.arrow'.format(
                                                                                      train_dataset_shard_idx)),
                                              cache_file_map_name=os.path.join(cache_processing_dataset_folder,
                                                                               'train',
                                                                               'cache-train-map-shard-{}.arrow'.format(
                                                                                   train_dataset_shard_idx)),
                                              )  # .shard(1000, 0) takes a small subset for quick tests; use the full shard when training
        # load test shard subset
        test_dataset = load_prepared_dataset(os.path.join(test_dataset_root_folder,
                                                          'shard_{}'.format(test_dataset_shard_idx)),
                                             w2v_ctc_processor,
                                             cache_file_filter_name=os.path.join(cache_processing_dataset_folder,
                                                                                 'test',
                                                                                 'cache-test-filter-shard-{}.arrow'.format(
                                                                                     test_dataset_shard_idx)),
                                             cache_file_map_name=os.path.join(cache_processing_dataset_folder, 'test',
                                                                              'cache-test-map-shard-{}.arrow'.format(
                                                                                  test_dataset_shard_idx))
                                             )
        if train_dataset is None or test_dataset is None:
            print("Ignore Shard {}".format(train_dataset_shard_idx))
            continue
        test_dataset = test_dataset.shard(num_test_sub_shard, idx_sub_shard)

        # Start prefetching the next epoch's shards in a background process
        prefetch_process.append(Process(target=process_prefetch_epoch, args=(epoch_idx + 1,)))
        for process_instance in prefetch_process:
            process_instance.start()

        # Initialise the Trainer once and reuse it across shards, swapping the datasets in place.
        # BreakEachEpoch (from callbacks.py, not shown here) stops the Trainer at the end of
        # every epoch so that this outer loop regains control and can load the next shard.
        if trainer is None:
            trainer = Trainer(
                model=w2v_ctc_model,
                data_collator=data_collator,
                args=training_args,
                compute_metrics=compute_metrics_fn(w2v_ctc_processor),
                train_dataset=train_dataset,
                eval_dataset=test_dataset,
                tokenizer=w2v_ctc_processor.feature_extractor,
                callbacks=[BreakEachEpoch()]  # manually break at the end of each epoch because every epoch covers a single shard
            )
        else:
            trainer.train_dataset = train_dataset
            trainer.eval_dataset = test_dataset

        logging.get_logger().info('Train shard idx: {} / {}'.format(train_dataset_shard_idx + 1, num_train_shards))
        logging.get_logger().info(
            'Valid shard idx: {} / {} sub_shard: {}'.format(test_dataset_shard_idx + 1, num_test_shards, idx_sub_shard))

        if last_checkpoint_path is not None:
            # resume training from the last checkpoint if one exists
            trainer.train(resume_from_checkpoint=True)
        else:
            # train from pre-trained wav2vec2 checkpoint
            trainer.train()
        last_checkpoint_path = get_last_checkpoint(checkpoint_path)

        # Clear cache files to free disk space
        test_dataset.cleanup_cache_files()
        train_dataset.cleanup_cache_files()

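        # Push checkpoints to the repository every 5 shard-epochs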
        if epoch_idx % 5 == 0:
            commit_checkpoint()