vaw2tmp / main.py
Check's picture
fix error when read shard
2808233
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, \
TrainerCallback
from datasets import load_from_disk
from data_handler import DataCollatorCTCWithPadding
from transformers import TrainingArguments
from transformers import Trainer, logging
from metric_utils import compute_metrics_fn
from transformers.trainer_utils import get_last_checkpoint
import json
import os, glob
from callbacks import BreakEachEpoch
import subprocess
from multiprocessing import Process
import shutil
# Set the verbosity of the `transformers` logging module (imported above) to INFO.
logging.set_verbosity_info()
def load_pretrained_model(checkpoint_path=None):
    """Build the wav2vec2 CTC model together with its processor.

    When ``checkpoint_path`` is ``None`` the feature extractor and weights are
    read from ``./model-bin/pretrained/base``, the tokenizer is built from
    ``./model-bin/finetune/vocab.json`` and ``freeze_feature_extractor()`` is
    called on the model. Otherwise tokenizer, feature extractor and weights
    are all read from ``checkpoint_path`` and this function does not freeze
    anything.

    The model and its total / trainable parameter counts are printed.

    Args:
        checkpoint_path: Folder of a fine-tuned checkpoint, or ``None`` to
            start from the pretrained base model.

    Returns:
        A ``(model, processor)`` tuple.
    """
    start_from_base = checkpoint_path is None

    # Only the tokenizer is created differently in the two modes; everything
    # else is loaded from ``source_dir``.
    if start_from_base:
        source_dir = './model-bin/pretrained/base'
        ctc_tokenizer = Wav2Vec2CTCTokenizer(
            "./model-bin/finetune/vocab.json",
            unk_token="<unk>",
            pad_token="<pad>",
            word_delimiter_token="|",
        )
    else:
        source_dir = checkpoint_path
        ctc_tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(source_dir)

    extractor = Wav2Vec2FeatureExtractor.from_pretrained(source_dir)
    processor = Wav2Vec2Processor(feature_extractor=extractor, tokenizer=ctc_tokenizer)

    model = Wav2Vec2ForCTC.from_pretrained(
        source_dir,
        gradient_checkpointing=True,
        ctc_loss_reduction="mean",
        pad_token_id=processor.tokenizer.pad_token_id,
    )
    if start_from_base:
        model.freeze_feature_extractor()

    # Single pass over the parameters for both counters.
    total_count = 0
    trainable_count = 0
    for parameter in model.parameters():
        size = parameter.numel()
        total_count += size
        if parameter.requires_grad:
            trainable_count += size

    print(model)
    print("model_total_params: {}\nmodel_total_params_trainable: {}".format(total_count,
                                                                            trainable_count))
    return model, processor
def prepare_dataset(batch, processor):
    """Turn a batch of raw speech and transcripts into model inputs and labels.

    Args:
        batch: Mapping with the columns ``"speech"``, ``"sampling_rate"`` and
            ``"target_text"``, each holding one entry per example.
        processor: Callable processor. It is called with the speech and the
            batch's sampling rate to obtain ``input_values`` and, inside its
            ``as_target_processor()`` context, with the transcripts to obtain
            ``input_ids``.

    Returns:
        The same ``batch`` object with ``"input_values"``, ``"length"`` (the
        length of every ``input_values`` entry) and ``"labels"`` added.

    Raises:
        AssertionError: If the batch does not contain exactly one distinct
            sampling rate.
    """
    # check that all files have the correct sampling rate
    # Raised explicitly instead of using ``assert`` so the check is not
    # stripped when Python runs with ``-O``; the exception type is unchanged.
    if len(set(batch["sampling_rate"])) != 1:
        raise AssertionError(
            f"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}."
        )

    batch["input_values"] = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0]).input_values
    batch["length"] = [len(item) for item in batch["input_values"]]

    with processor.as_target_processor():
        batch["labels"] = processor(batch["target_text"]).input_ids
    return batch
def load_prepared_dataset(path, processor, cache_file_filter_name, cache_file_map_name, num_proc=5):
    """Load one dataset shard from disk and return it filtered and processed.

    Reads the module-level ``cache_processing_dataset_folder`` and
    ``cache_processing_dataset_folder_prefetch`` paths set by the script
    entry point.

    Behaviour depends on where ``cache_file_map_name`` points:

    * Prefetch mode (the name starts with the prefetch folder): returns
      ``None`` straight away when matching cache files already exist in the
      working cache folder or in the prefetch folder; otherwise the cache is
      computed.
    * Otherwise: when the working cache has no matching files but the prefetch
      folder does, those files are moved into the working cache folder. If
      matching cache files then exist, the shard is passed to ``map`` with
      that cache file name and returned. If not, the shard is filtered to
      examples with fewer than 160000 ``speech`` samples, mapped through
      ``prepare_dataset`` and unused cache files are cleaned up.

    Args:
        path: Folder of the shard, as accepted by ``load_from_disk``.
        processor: Processor forwarded to ``prepare_dataset``.
        cache_file_filter_name: Cache file name used by the filter step.
        cache_file_map_name: Cache file name used by the map step.
        num_proc: Number of worker processes for ``filter`` and ``map``.

    Returns:
        The processed dataset, or ``None`` when prefetching was skipped or any
        step raised an exception (the failure is logged with its traceback).
    """

    def matching_cache_files(cache_file_name):
        # Matched with a wildcard instead of the exact '.arrow' name --
        # presumably because multi-process runs write suffixed files.
        return glob.glob(cache_file_name.replace('.arrow', '*'))

    def run_map(source):
        return source.map(prepare_dataset,
                          remove_columns=source.column_names,
                          batch_size=32,
                          num_proc=num_proc,
                          batched=True,
                          fn_kwargs={"processor": processor},
                          cache_file_name=cache_file_map_name)

    try:
        dataset = load_from_disk(path)
        list_cache_prefetch_files = matching_cache_files(
            cache_file_map_name.replace(cache_processing_dataset_folder,
                                        cache_processing_dataset_folder_prefetch))

        # Do not re-compute what already in cache folder
        if cache_file_map_name.startswith(cache_processing_dataset_folder_prefetch):
            working_cache_files = matching_cache_files(
                cache_file_map_name.replace(cache_processing_dataset_folder_prefetch,
                                            cache_processing_dataset_folder))
            if working_cache_files or list_cache_prefetch_files:
                return None

        # check cache file: adopt prefetched files when the working cache is empty
        if not matching_cache_files(cache_file_map_name) and list_cache_prefetch_files:
            for item_file in list_cache_prefetch_files:
                shutil.move(item_file, item_file.replace(cache_processing_dataset_folder_prefetch,
                                                         cache_processing_dataset_folder))

        if matching_cache_files(cache_file_map_name):
            return run_map(dataset)

        dataset = dataset.filter(lambda example: len(example['speech']) < 160000,
                                 batch_size=32,
                                 num_proc=num_proc,
                                 cache_file_name=cache_file_filter_name)
        processed_dataset = run_map(dataset)
        processed_dataset.cleanup_cache_files()
        return processed_dataset
    except Exception:
        # Best effort: callers treat None as "skip this shard". Unlike a bare
        # ``except:``, KeyboardInterrupt/SystemExit still propagate and the
        # reason for the failure is logged instead of being silently dropped.
        logging.get_logger().warning("Could not prepare dataset shard %s", path, exc_info=True)
        return None
def commit_checkpoint():
    """Stage, commit and push the fine-tuned checkpoint folder with git.

    Runs ``git add``, ``git commit`` and ``git push origin main`` in that
    order and prints the captured stdout of each command. Exit statuses are
    not checked, so a command that exits non-zero does not stop the ones
    after it.
    """
    # Explicit argv lists: splitting a command string on whitespace passed the
    # double quotes around the commit message to git as part of the message.
    submit_commands = [
        ['git', 'add', 'model-bin/finetune/base/*'],
        ['git', 'commit', '-m', 'auto-commit'],
        ['git', 'push', 'origin', 'main'],
    ]
    for argv in submit_commands:
        completed = subprocess.run(argv, stdout=subprocess.PIPE)
        # errors='replace' keeps unexpected non-UTF-8 output from raising here.
        print(completed.stdout.decode('utf-8', errors='replace'))
def get_train_test_shard_id(epoch_count):
    """Map an epoch counter to the train/test shards used for that epoch.

    Reads the module-level ``num_train_shards`` and ``num_test_shards``.

    Args:
        epoch_count: Epoch counter; it wraps around the train shards.

    Returns:
        A tuple ``(train_shard_idx, test_shard_idx, num_test_sub_shard,
        idx_sub_shard)``.
    """
    sub_shard_count = 8  # Split test shard into subset. Default is 8

    # loop over training shards
    train_idx = epoch_count % num_train_shards

    # Get test shard depend on train shard id, clamped to the last test shard
    train_per_test = num_train_shards / num_test_shards
    test_idx = round(train_idx / train_per_test)
    last_test_idx = num_test_shards - 1
    if test_idx > last_test_idx:
        test_idx = last_test_idx

    # loop over test shard subset
    sub_idx = train_idx % sub_shard_count
    return train_idx, test_idx, sub_shard_count, sub_idx
def process_prefetch_epoch(epoch_count):
    """Prepare the train and test shard caches for ``epoch_count`` ahead of time.

    Calls ``load_prepared_dataset`` for the train shard and then for the test
    shard selected by ``get_train_test_shard_id``, using cache file names
    inside the prefetch cache folder. The returned datasets are discarded.

    Args:
        epoch_count: Epoch counter whose shards should be prepared.
    """
    train_shard_idx, test_shard_idx, _, _ = get_train_test_shard_id(epoch_count)

    # (dataset root, cache sub-folder, shard index, filter cache name, map cache name)
    jobs = (
        (train_dataset_root_folder, 'train', train_shard_idx,
         'cache-train-filter-shard-{}.arrow', 'cache-train-map-shard-{}.arrow'),
        (test_dataset_root_folder, 'test', test_shard_idx,
         'cache-test-filter-shard-{}.arrow', 'cache-test-map-shard-{}.arrow'),
    )
    for root_folder, split, shard_idx, filter_template, map_template in jobs:
        shard_folder = os.path.join(root_folder, 'shard_{}'.format(shard_idx))
        filter_cache = os.path.join(cache_processing_dataset_folder_prefetch,
                                    split,
                                    filter_template.format(shard_idx))
        map_cache = os.path.join(cache_processing_dataset_folder_prefetch,
                                 split,
                                 map_template.format(shard_idx))
        load_prepared_dataset(shard_folder,
                              w2v_ctc_processor,
                              cache_file_filter_name=filter_cache,
                              cache_file_map_name=map_cache)
if __name__ == "__main__":
    # Script entry point. Each Trainer "epoch" trains on one dataset shard;
    # while the current shard trains, the next epoch's shards are prepared in
    # a background process.
    checkpoint_path = "./model-bin/finetune/base/"

    # Shards are read from the mounted Drive folder; for a local copy point
    # these at './data-bin/train_dataset' and './data-bin/test_dataset'.
    train_dataset_root_folder = '/content/drive/MyDrive/audio_dataset/train_dataset'
    test_dataset_root_folder = '/content/drive/MyDrive/audio_dataset/test_dataset'

    cache_processing_dataset_folder = '/dev/shm/cache/'
    cache_processing_dataset_folder_prefetch = './data-bin/cache_prefetch/'

    # Create every cache folder independently. The previous check only looked
    # at the 'train' folder, so a missing 'test' folder was never created and
    # an existing one made makedirs raise.
    for cache_root in (cache_processing_dataset_folder, cache_processing_dataset_folder_prefetch):
        for split_name in ('train', 'test'):
            os.makedirs(os.path.join(cache_root, split_name), exist_ok=True)

    num_train_shards = len(glob.glob(os.path.join(train_dataset_root_folder, 'shard_*')))
    num_test_shards = len(glob.glob(os.path.join(test_dataset_root_folder, 'shard_*')))
    if num_train_shards == 0 or num_test_shards == 0:
        # get_train_test_shard_id divides by both counts; fail early with a
        # clear message instead of a ZeroDivisionError after model loading.
        raise FileNotFoundError(
            "No 'shard_*' folders found (train: {}, test: {}) under {} / {}".format(
                num_train_shards, num_test_shards, train_dataset_root_folder, test_dataset_root_folder))

    num_epochs = 5000

    training_args = TrainingArguments(
        output_dir=checkpoint_path,
        fp16=True,
        group_by_length=True,
        per_device_train_batch_size=32,
        per_device_eval_batch_size=32,
        gradient_accumulation_steps=2,
        num_train_epochs=num_epochs,  # each epoch per shard data
        logging_steps=5,
        learning_rate=1e-5,
        weight_decay=0.005,
        warmup_steps=1000,
        save_total_limit=2,
        ignore_data_skip=True,
        logging_dir=os.path.join(checkpoint_path, 'log'),
        metric_for_best_model='wer',
        save_strategy="epoch",
        evaluation_strategy="epoch",
        greater_is_better=False,
    )
    trainer = None

    # Resume bookkeeping: continue the epoch counter from the newest
    # checkpoint found in the output folder, if any.
    last_checkpoint_path = None
    last_epoch_idx = 0
    if os.path.exists(checkpoint_path):
        last_checkpoint_path = get_last_checkpoint(checkpoint_path)
        if last_checkpoint_path is not None:
            with open(os.path.join(last_checkpoint_path, "trainer_state.json"), 'r', encoding='utf-8') as file:
                trainer_state = json.load(file)
                last_epoch_idx = int(trainer_state['epoch'])

    w2v_ctc_model, w2v_ctc_processor = load_pretrained_model()
    data_collator = DataCollatorCTCWithPadding(processor=w2v_ctc_processor, padding=True)

    prefetch_process = []

    for epoch_idx in range(last_epoch_idx, num_epochs):
        train_dataset_shard_idx, test_dataset_shard_idx, num_test_sub_shard, idx_sub_shard = get_train_test_shard_id(
            epoch_idx)

        # waiting for all prefetch process done
        for process_instance in prefetch_process:
            process_instance.join()
        prefetch_process.clear()

        # load train shard
        train_dataset = load_prepared_dataset(os.path.join(train_dataset_root_folder,
                                                           'shard_{}'.format(train_dataset_shard_idx)),
                                              w2v_ctc_processor,
                                              cache_file_filter_name=os.path.join(cache_processing_dataset_folder,
                                                                                  'train',
                                                                                  'cache-train-filter-shard-{}.arrow'.format(
                                                                                      train_dataset_shard_idx)),
                                              cache_file_map_name=os.path.join(cache_processing_dataset_folder,
                                                                               'train',
                                                                               'cache-train-map-shard-{}.arrow'.format(
                                                                                   train_dataset_shard_idx)),
                                              )
        # load test shard subset
        test_dataset = load_prepared_dataset(os.path.join(test_dataset_root_folder,
                                                          'shard_{}'.format(test_dataset_shard_idx)),
                                             w2v_ctc_processor,
                                             cache_file_filter_name=os.path.join(cache_processing_dataset_folder,
                                                                                 'test',
                                                                                 'cache-test-filter-shard-{}.arrow'.format(
                                                                                     test_dataset_shard_idx)),
                                             cache_file_map_name=os.path.join(cache_processing_dataset_folder, 'test',
                                                                              'cache-test-map-shard-{}.arrow'.format(
                                                                                  test_dataset_shard_idx))
                                             )
        # load_prepared_dataset returns None when a shard could not be prepared
        if train_dataset is None or test_dataset is None:
            print("Ignore Shard {}".format(train_dataset_shard_idx))
            continue
        test_dataset = test_dataset.shard(num_test_sub_shard, idx_sub_shard)

        # Prefetch_dataset: prepare the next epoch's shards while this one trains.
        # NOTE(review): the child reads module-level names set above, which
        # presumably relies on the fork start method -- confirm on other platforms.
        prefetch_process.append(Process(target=process_prefetch_epoch, args=(epoch_idx + 1,)))
        for process_instance in prefetch_process:
            process_instance.start()

        # Init trainer once, afterwards only swap the datasets
        if trainer is None:
            trainer = Trainer(
                model=w2v_ctc_model,
                data_collator=data_collator,
                args=training_args,
                compute_metrics=compute_metrics_fn(w2v_ctc_processor),
                train_dataset=train_dataset,
                eval_dataset=test_dataset,
                tokenizer=w2v_ctc_processor.feature_extractor,
                callbacks=[BreakEachEpoch()]  # Manual break end of epoch because each epoch loop over a shard
            )
        else:
            trainer.train_dataset = train_dataset
            trainer.eval_dataset = test_dataset

        logging.get_logger().info('Train shard idx: {} / {}'.format(train_dataset_shard_idx + 1, num_train_shards))
        logging.get_logger().info(
            'Valid shard idx: {} / {} sub_shard: {}'.format(test_dataset_shard_idx + 1, num_test_shards, idx_sub_shard))

        if last_checkpoint_path is not None:
            # start train from a checkpoint if exist
            trainer.train(resume_from_checkpoint=True)
        else:
            # train from pre-trained wav2vec2 checkpoint
            trainer.train()
        last_checkpoint_path = get_last_checkpoint(checkpoint_path)

        # Clear cache file to free disk
        test_dataset.cleanup_cache_files()
        train_dataset.cleanup_cache_files()

        # Push the checkpoint folder every fifth epoch index
        if epoch_idx % 5 == 0:
            commit_checkpoint()