# Olive_Whisper_ASR/utils/callback.py
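"""Callback utilities for PEFT fine-tuning.

SavePeftModelCallback saves the PEFT adapter next to every Trainer checkpoint
and keeps a stable copy of the best-scoring checkpoint.
"""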
import os
import shutil
from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
# Callback that saves the PEFT adapter with each checkpoint and mirrors the
# best checkpoint to a fixed "checkpoint-best" folder.
class SavePeftModelCallback(TrainerCallback):
    def on_save(self,
                args: TrainingArguments,
                state: TrainerState,
                control: TrainerControl,
                **kwargs):
        # Only the main process writes to disk (local_rank is 0 for the main
        # worker and -1 when not running distributed).
        if args.local_rank == 0 or args.local_rank == -1:
            # Save the PEFT adapter into the current checkpoint folder.
            checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
            peft_model_dir = os.path.join(checkpoint_folder, "adapter_model")
            kwargs["model"].save_pretrained(peft_model_dir)
            peft_config_path = os.path.join(peft_model_dir, "adapter_config.json")
            peft_model_path = os.path.join(peft_model_dir, "adapter_model.bin")
            # If save_pretrained did not produce both adapter files, discard the
            # incomplete directory rather than keep a broken adapter checkpoint.
            if not (os.path.exists(peft_config_path) and os.path.exists(peft_model_path)):
                if os.path.exists(peft_model_dir):
                    shutil.rmtree(peft_model_dir)
            # Keep a stable copy of the best checkpoint so it survives the
            # Trainer's rotation of old checkpoints.
            best_checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-best")
            # best_model_checkpoint is None until the first evaluation, so guard it.
            if state.best_model_checkpoint is not None and os.path.exists(state.best_model_checkpoint):
                if os.path.exists(best_checkpoint_folder):
                    shutil.rmtree(best_checkpoint_folder)
                shutil.copytree(state.best_model_checkpoint, best_checkpoint_folder)
            print(f"Best checkpoint: {state.best_model_checkpoint}, metric: {state.best_metric}")
        return control
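

# Usage sketch (illustrative, not part of the original file): register the
# callback when constructing the trainer so every checkpoint save also writes
# the adapter. `Seq2SeqTrainer` and the variable names below are assumptions
# for the example, not taken from this repository.
#
#   from transformers import Seq2SeqTrainer
#
#   trainer = Seq2SeqTrainer(
#       model=model,                      # a PEFT-wrapped Whisper model
#       args=training_args,
#       train_dataset=train_dataset,
#       eval_dataset=eval_dataset,
#       callbacks=[SavePeftModelCallback()],
#   )
#   trainer.train()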