from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, \ TrainerCallback from datasets import load_from_disk from data_handler import DataCollatorCTCWithPadding from transformers import TrainingArguments from transformers import Trainer, logging from metric_utils import compute_metrics_fn from transformers.trainer_utils import get_last_checkpoint import json import os, glob from callbacks import BreakEachEpoch import subprocess from multiprocessing import Process import shutil logging.set_verbosity_info() def load_pretrained_model(checkpoint_path=None): if checkpoint_path is None: pre_trained_path = './model-bin/pretrained/base' tokenizer = Wav2Vec2CTCTokenizer("./model-bin/finetune/vocab.json", unk_token="", pad_token="", word_delimiter_token="|") feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(pre_trained_path) processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer) model = Wav2Vec2ForCTC.from_pretrained( pre_trained_path, gradient_checkpointing=True, ctc_loss_reduction="mean", pad_token_id=processor.tokenizer.pad_token_id, ) model.freeze_feature_extractor() else: tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(checkpoint_path) feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(checkpoint_path) processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer) model = Wav2Vec2ForCTC.from_pretrained( checkpoint_path, gradient_checkpointing=True, ctc_loss_reduction="mean", pad_token_id=processor.tokenizer.pad_token_id, ) # model.freeze_feature_extractor() # model = Wav2Vec2ForCTC(model.config) model_total_params = sum(p.numel() for p in model.parameters()) model_total_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) print(model) print("model_total_params: {}\nmodel_total_params_trainable: {}".format(model_total_params, model_total_params_trainable)) return model, processor def prepare_dataset(batch, processor): # check that all files have the correct sampling rate assert ( len(set(batch["sampling_rate"])) == 1 ), f"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}." batch["input_values"] = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0]).input_values batch["length"] = [len(item) for item in batch["input_values"]] with processor.as_target_processor(): batch["labels"] = processor(batch["target_text"]).input_ids return batch def load_prepared_dataset(path, processor, cache_file_filter_name, cache_file_map_name, num_proc=5): try: dataset = load_from_disk(path) list_cache_prefetch_files = glob.glob( cache_file_map_name.replace(cache_processing_dataset_folder, cache_processing_dataset_folder_prefetch).replace( '.arrow', '*')) # Do not re-compute what already in cache folder if cache_file_map_name.startswith(cache_processing_dataset_folder_prefetch): if len(glob.glob(cache_file_map_name.replace(cache_processing_dataset_folder_prefetch, cache_processing_dataset_folder).replace('.arrow', '*'))) > 0: return if len(list_cache_prefetch_files) > 0: return # check cache file if len(glob.glob(cache_file_map_name.replace('.arrow', '*'))) == 0 and len(list_cache_prefetch_files) > 0: for item_file in list_cache_prefetch_files: shutil.move(item_file, item_file.replace(cache_processing_dataset_folder_prefetch, cache_processing_dataset_folder)) if len(glob.glob(cache_file_map_name.replace('.arrow', '*'))) > 0: return dataset.map(prepare_dataset, remove_columns=dataset.column_names, batch_size=32, num_proc=num_proc, batched=True, fn_kwargs={"processor": processor}, cache_file_name=cache_file_map_name) dataset = dataset.filter(lambda example: len(example['speech']) < 160000, batch_size=32, num_proc=num_proc, cache_file_name=cache_file_filter_name) processed_dataset = dataset.map(prepare_dataset, remove_columns=dataset.column_names, batch_size=32, num_proc=num_proc, batched=True, fn_kwargs={"processor": processor}, cache_file_name=cache_file_map_name) processed_dataset.cleanup_cache_files() return processed_dataset except: return None def commit_checkpoint(): submit_commands = [ 'git add model-bin/finetune/base/*', 'git commit -m "auto-commit"', 'git push origin main' ] for command in submit_commands: print(subprocess.run(command.split(), stdout=subprocess.PIPE).stdout.decode('utf-8')) def get_train_test_shard_id(epoch_count): # loop over training shards _train_dataset_shard_idx = epoch_count % num_train_shards # Get test shard depend on train shard id _test_dataset_shard_idx = min(round(_train_dataset_shard_idx / (num_train_shards / num_test_shards)), num_test_shards - 1) _num_test_sub_shard = 8 # Split test shard into subset. Default is 8 _idx_sub_shard = _train_dataset_shard_idx % _num_test_sub_shard # loop over test shard subset return _train_dataset_shard_idx, _test_dataset_shard_idx, _num_test_sub_shard, _idx_sub_shard def process_prefetch_epoch(epoch_count): train_shard_idx, test_shard_idx, _, _ = get_train_test_shard_id(epoch_count) load_prepared_dataset(os.path.join(train_dataset_root_folder, 'shard_{}'.format(train_shard_idx)), w2v_ctc_processor, cache_file_filter_name=os.path.join(cache_processing_dataset_folder_prefetch, 'train', 'cache-train-filter-shard-{}.arrow'.format( train_shard_idx)), cache_file_map_name=os.path.join(cache_processing_dataset_folder_prefetch, 'train', 'cache-train-map-shard-{}.arrow'.format( train_shard_idx)), ) load_prepared_dataset(os.path.join(test_dataset_root_folder, 'shard_{}'.format(test_shard_idx)), w2v_ctc_processor, cache_file_filter_name=os.path.join(cache_processing_dataset_folder_prefetch, 'test', 'cache-test-filter-shard-{}.arrow'.format( test_shard_idx)), cache_file_map_name=os.path.join(cache_processing_dataset_folder_prefetch, 'test', 'cache-test-map-shard-{}.arrow'.format( test_shard_idx)) ) if __name__ == "__main__": checkpoint_path = "./model-bin/finetune/base/" # train_dataset_root_folder = './data-bin/train_dataset' # test_dataset_root_folder = './data-bin/test_dataset' train_dataset_root_folder = '/content/drive/MyDrive/audio_dataset/train_dataset' test_dataset_root_folder = '/content/drive/MyDrive/audio_dataset/test_dataset' cache_processing_dataset_folder = '/dev/shm/cache/' cache_processing_dataset_folder_prefetch = './data-bin/cache_prefetch/' if not os.path.exists(os.path.join(cache_processing_dataset_folder, 'train')): os.makedirs(os.path.join(cache_processing_dataset_folder, 'train')) os.makedirs(os.path.join(cache_processing_dataset_folder, 'test')) if not os.path.exists(os.path.join(cache_processing_dataset_folder_prefetch, 'train')): os.makedirs(os.path.join(cache_processing_dataset_folder_prefetch, 'train')) os.makedirs(os.path.join(cache_processing_dataset_folder_prefetch, 'test')) num_train_shards = len(glob.glob(os.path.join(train_dataset_root_folder, 'shard_*'))) num_test_shards = len(glob.glob(os.path.join(test_dataset_root_folder, 'shard_*'))) num_epochs = 5000 training_args = TrainingArguments( output_dir=checkpoint_path, fp16=True, group_by_length=True, per_device_train_batch_size=32, per_device_eval_batch_size=32, gradient_accumulation_steps=2, num_train_epochs=num_epochs, # each epoch per shard data logging_steps=5, learning_rate=1e-5, weight_decay=0.005, warmup_steps=1000, save_total_limit=2, ignore_data_skip=True, logging_dir=os.path.join(checkpoint_path, 'log'), metric_for_best_model='wer', save_strategy="epoch", evaluation_strategy="epoch", greater_is_better=False, # save_steps=5, # eval_steps=5, ) trainer = None # PretrainedConfig.from_json_file(os.path.join(resume_from_checkpoint, CONFIG_NAME)) last_checkpoint_path = None last_epoch_idx = 0 if os.path.exists(checkpoint_path): last_checkpoint_path = get_last_checkpoint(checkpoint_path) if last_checkpoint_path is not None: with open(os.path.join(last_checkpoint_path, "trainer_state.json"), 'r', encoding='utf-8') as file: trainer_state = json.load(file) last_epoch_idx = int(trainer_state['epoch']) w2v_ctc_model, w2v_ctc_processor = load_pretrained_model() data_collator = DataCollatorCTCWithPadding(processor=w2v_ctc_processor, padding=True) prefetch_process = [] for epoch_idx in range(last_epoch_idx, num_epochs): # # loop over training shards # train_dataset_shard_idx = epoch_idx % num_train_shards # # Get test shard depend on train shard id # test_dataset_shard_idx = round(train_dataset_shard_idx / (num_train_shards / num_test_shards)) # num_test_sub_shard = 8 # Split test shard into subset. Default is 8 # idx_sub_shard = train_dataset_shard_idx % num_test_sub_shard # loop over test shard subset train_dataset_shard_idx, test_dataset_shard_idx, num_test_sub_shard, idx_sub_shard = get_train_test_shard_id( epoch_idx) # waiting for all prefetch process done for process_instance in prefetch_process: process_instance.join() prefetch_process.clear() # load train shard train_dataset = load_prepared_dataset(os.path.join(train_dataset_root_folder, 'shard_{}'.format(train_dataset_shard_idx)), w2v_ctc_processor, cache_file_filter_name=os.path.join(cache_processing_dataset_folder, 'train', 'cache-train-filter-shard-{}.arrow'.format( train_dataset_shard_idx)), cache_file_map_name=os.path.join(cache_processing_dataset_folder, 'train', 'cache-train-map-shard-{}.arrow'.format( train_dataset_shard_idx)), ) # .shard(1000, 0) # Remove shard split when train # load test shard subset test_dataset = load_prepared_dataset(os.path.join(test_dataset_root_folder, 'shard_{}'.format(test_dataset_shard_idx)), w2v_ctc_processor, cache_file_filter_name=os.path.join(cache_processing_dataset_folder, 'test', 'cache-test-filter-shard-{}.arrow'.format( test_dataset_shard_idx)), cache_file_map_name=os.path.join(cache_processing_dataset_folder, 'test', 'cache-test-map-shard-{}.arrow'.format( test_dataset_shard_idx)) ) if train_dataset is None or test_dataset is None: print("Ignore Shard {}".format(train_dataset_shard_idx)) continue test_dataset = test_dataset.shard(num_test_sub_shard, idx_sub_shard) # Prefetch_dataset prefetch_process.append(Process(target=process_prefetch_epoch, args=(epoch_idx + 1,))) for process_instance in prefetch_process: process_instance.start() # Init trainer if trainer is None: trainer = Trainer( model=w2v_ctc_model, data_collator=data_collator, args=training_args, compute_metrics=compute_metrics_fn(w2v_ctc_processor), train_dataset=train_dataset, eval_dataset=test_dataset, tokenizer=w2v_ctc_processor.feature_extractor, callbacks=[BreakEachEpoch()] # Manual break end of epoch because each epoch loop over a shard ) else: trainer.train_dataset = train_dataset trainer.eval_dataset = test_dataset logging.get_logger().info('Train shard idx: {} / {}'.format(train_dataset_shard_idx + 1, num_train_shards)) logging.get_logger().info( 'Valid shard idx: {} / {} sub_shard: {}'.format(test_dataset_shard_idx + 1, num_test_shards, idx_sub_shard)) if last_checkpoint_path is not None: # start train from a checkpoint if exist trainer.train(resume_from_checkpoint=True) else: # train from pre-trained wav2vec2 checkpoint trainer.train() last_checkpoint_path = get_last_checkpoint(checkpoint_path) # Clear cache file to free disk test_dataset.cleanup_cache_files() train_dataset.cleanup_cache_files() if epoch_idx % 5 == 0: commit_checkpoint()