# Olive_Whisper_ASR / finetune.py
# Last commit by sam2ai: "Synced repo using 'sync_with_huggingface' Github Action"
# (commit 6de3e11, 7.93 kB).
import argparse
import functools
import os
import platform
import torch
from peft import LoraConfig, get_peft_model, AdaLoraConfig, PeftModel, prepare_model_for_kbit_training
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, WhisperForConditionalGeneration, WhisperProcessor
from utils.callback import SavePeftModelCallback
from utils.data_utils import DataCollatorSpeechSeq2SeqWithPadding
from utils.model_utils import load_from_checkpoint
from utils.reader import CustomDataset
from utils.utils import print_arguments, make_inputs_require_grad, add_arguments
# ---------------------------------------------------------------------------
# Command-line options.
#
# Every option is registered through `add_arg`, i.e. `add_arguments` from
# utils.utils bound to this parser. Each row below is
#     (name, type, default, help text, extra keyword arguments)
# and rows are registered in order, which is also the order argparse lists
# them in --help.
# NOTE(review): options declared with type=bool rely on `add_arguments` to
# convert the command-line string; with plain argparse, type=bool turns any
# non-empty string (including "False") into True. Confirm the helper handles
# this.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
_CLI_OPTIONS = [
    # Data and model locations.
    ("train_data", str, "dataset/train.json", "", {}),
    ("test_data", str, "dataset/test.json", "", {}),
    ("base_model", str, "openai/whisper-tiny", "Whisper", {}),
    ("output_dir", str, "output/", "", {}),
    # Step intervals for warm-up, logging, evaluation and checkpointing.
    ("warmup_steps", int, 50, "", {}),
    ("logging_steps", int, 100, "", {}),
    ("eval_steps", int, 1000, "", {}),
    ("save_steps", int, 1000, "", {}),
    # Data loading and optimisation.
    ("num_workers", int, 8, "", {}),
    ("learning_rate", float, 1e-3, "", {}),
    # Audio clip length bounds handed to the datasets.
    ("min_audio_len", float, 0.5, "", {}),
    ("max_audio_len", float, 30, "", {}),
    # Feature switches.
    ("use_adalora", bool, True, "AdaLora/Lora", {}),
    ("fp16", bool, True, "fp16", {}),
    ("use_8bit", bool, False, "8 bit", {}),
    ("timestamps", bool, False, "", {}),
    ("local_files_only", bool, False, "", {}),
    ("num_train_epochs", int, 3, "", {}),
    ("language", str, "bn", "", {}),
    # The help text of this option means "the model's task".
    ("task", str, "transcribe", "模型的任务", {"choices": ['transcribe', 'translate']}),
    ("augment_config_path", str, None, "", {}),
    ("resume_from_checkpoint", str, None, "", {}),
    ("per_device_train_batch_size", int, 8, "batch size", {}),
    ("per_device_eval_batch_size", int, 8, "batch size", {}),
    ("gradient_accumulation_steps", int, 1, "", {}),
]
for _opt_name, _opt_type, _opt_default, _opt_help, _opt_extra in _CLI_OPTIONS:
    add_arg(_opt_name, type=_opt_type, default=_opt_default, help=_opt_help, **_opt_extra)
args = parser.parse_args()
print_arguments(args)
# Whisper processor (feature extractor + tokenizer) configured for the chosen
# language and task; `no_timestamps` is the inverse of --timestamps.
processor = WhisperProcessor.from_pretrained(
    args.base_model,
    language=args.language,
    task=args.task,
    no_timestamps=not args.timestamps,
    local_files_only=args.local_files_only,
)
# Settings shared by the train and test splits. min/max duration come from
# --min_audio_len / --max_audio_len (presumably a clip-length filter in
# seconds; confirm in utils/reader.py).
_dataset_kwargs = dict(
    processor=processor,
    language=args.language,
    timestamps=args.timestamps,
    min_duration=args.min_audio_len,
    max_duration=args.max_audio_len,
)
# Only the training split receives the augmentation config.
train_dataset = CustomDataset(data_list_path=args.train_data,
                              augment_config_path=args.augment_config_path,
                              **_dataset_kwargs)
test_dataset = CustomDataset(data_list_path=args.test_data, **_dataset_kwargs)
print(f"len train - {len(train_dataset)} test len - {len(test_dataset)}")
# Batch collator built on the same processor (pads each batch).
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
# Load the pretrained Whisper model. Under a multi-process (DDP) launch, i.e.
# WORLD_SIZE != 1, each process maps the whole model ("" key) onto its own
# LOCAL_RANK device; otherwise device placement is left to device_map="auto".
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else "auto"
model = WhisperForConditionalGeneration.from_pretrained(
    args.base_model,
    load_in_8bit=args.use_8bit,
    device_map=device_map,
    local_files_only=args.local_files_only,
)
# Drop decoding constraints stored in the checkpoint config: no forced decoder
# ids and no suppressed tokens.
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []
# peft's k-bit preparation is applied whether or not --use_8bit is set.
model = prepare_model_for_kbit_training(model)
# Forward hook on the first encoder convolution. NOTE(review): presumably
# marks its output as requiring grad so gradients reach the adapters while the
# base weights are frozen; confirm in utils/utils.py (make_inputs_require_grad).
model.model.encoder.conv1.register_forward_hook(make_inputs_require_grad)
# Attach LoRA / AdaLoRA adapters to the model (message: "Loading LoRA modules...").
print('加载LoRA模块...')
if args.resume_from_checkpoint:
    # Resuming: re-attach the adapter weights stored in the checkpoint and keep
    # them trainable.
    print("Loading adapters from checkpoint.")
    model = PeftModel.from_pretrained(model, args.resume_from_checkpoint, is_trainable=True)
else:
    print('adding LoRA modules...')
    # Adapters go on the attention projections and the feed-forward layers.
    target_modules = ["k_proj", "q_proj", "v_proj", "out_proj", "fc1", "fc2"]
    print(target_modules)
    if args.use_adalora:
        # AdaLoRA's rank-budget schedule (tinit / tfinal / deltaT) is expressed
        # over the total number of optimizer steps, and recent peft releases
        # refuse to build an AdaLoraConfig whose `total_step` is unset.
        # The count below mirrors the Trainer's: per-device batches per epoch
        # across all processes (ceil), floor-divided by gradient accumulation
        # (at least 1), times the number of epochs. It assumes no max_steps and
        # no drop_last.
        # NOTE(review): those peft releases also require
        # tinit < total_step - tfinal, so very short runs (<= 1200 steps with
        # the values below) are still rejected.
        samples_per_batch = args.per_device_train_batch_size * world_size
        batches_per_epoch = (len(train_dataset) + samples_per_batch - 1) // samples_per_batch
        updates_per_epoch = max(batches_per_epoch // args.gradient_accumulation_steps, 1)
        total_step = updates_per_epoch * args.num_train_epochs
        config = AdaLoraConfig(init_r=12, target_r=4, beta1=0.85, beta2=0.85, tinit=200, tfinal=1000, deltaT=10,
                               lora_alpha=32, lora_dropout=0.1, orth_reg_weight=0.5, target_modules=target_modules,
                               total_step=total_step)
    else:
        config = LoraConfig(r=32, lora_alpha=64, target_modules=target_modules, lora_dropout=0.05, bias="none")
    model = get_peft_model(model, config)
# Checkpoints are written to <output_dir>/<model name>. normpath() strips a
# trailing path separator first, so a local model directory given as
# "models/whisper-small/" still yields "whisper-small" instead of an empty
# basename (which would put checkpoints straight into --output_dir).
output_dir = os.path.join(args.output_dir, os.path.basename(os.path.normpath(args.base_model)))
# Trainer configuration: step-based evaluation and checkpointing; the best
# checkpoint (by eval loss unless metric_for_best_model is set) is reloaded at
# the end of training.
training_args = \
    Seq2SeqTrainingArguments(output_dir=output_dir,  # Directory to save checkpoints
                             per_device_train_batch_size=args.per_device_train_batch_size,  # Training batch size per device
                             per_device_eval_batch_size=args.per_device_eval_batch_size,  # Eval batch size per device
                             gradient_accumulation_steps=args.gradient_accumulation_steps,  # Batches accumulated per optimizer step
                             learning_rate=args.learning_rate,  # Peak learning rate
                             warmup_steps=args.warmup_steps,  # Warm-up steps
                             num_train_epochs=args.num_train_epochs,  # Number of epochs
                             save_strategy="steps",  # Checkpoint every `save_steps`
                             # NOTE(review): newer transformers releases rename this to
                             # `eval_strategy`; confirm against the installed version.
                             evaluation_strategy="steps",  # Evaluate every `eval_steps`
                             # transformers requires save_steps to be a multiple of
                             # eval_steps when this is enabled.
                             load_best_model_at_end=True,
                             fp16=args.fp16,  # Mixed-precision training
                             report_to=["tensorboard"],  # Log to TensorBoard
                             save_steps=args.save_steps,  # Checkpoint interval
                             eval_steps=args.eval_steps,  # Evaluation interval
                             save_total_limit=5,  # Keep at most 5 checkpoints
                             optim='adamw_torch',  # PyTorch AdamW optimizer
                             ddp_find_unused_parameters=False if ddp else None,  # No unused-param search under DDP
                             dataloader_num_workers=args.num_workers,  # DataLoader worker processes
                             logging_steps=args.logging_steps,  # Logging interval
                             # Keep every dataset field for the custom collator instead of
                             # dropping those the model's forward() does not name.
                             remove_unused_columns=False,
                             label_names=["labels"])  # Batch key that holds the targets
# Only the local main process (local_rank 0, or -1 when not distributed) prints
# the trainable-parameter summary.
if training_args.local_rank in (0, -1):
    print('=' * 90)
    model.print_trainable_parameters()
    print('=' * 90)
# Use the PyTorch 2.x compiler where it is supported. torch.compile is not
# available on Windows in the 2.0-era PyTorch releases (it raises RuntimeError
# there), so Windows is excluded rather than selected.
if torch.__version__ >= "2" and platform.system().lower() != 'windows':
    model = torch.compile(model)
# The feature extractor is passed as `tokenizer` so the Trainer saves it with
# each checkpoint. SavePeftModelCallback presumably writes the adapter weights
# on save; confirm in utils/callback.py.
trainer = Seq2SeqTrainer(args=training_args,
                         model=model,
                         train_dataset=train_dataset,
                         eval_dataset=test_dataset,
                         data_collator=data_collator,
                         tokenizer=processor.feature_extractor,
                         callbacks=[SavePeftModelCallback])
# Turn off the decoder's key/value cache for training.
model.config.use_cache = False
# Shadow Trainer._load_from_checkpoint on this instance with the project
# helper. As an instance attribute it is not bound, so the Trainer's call
# passes only the checkpoint path.
# NOTE(review): presumably a no-op because the adapters were already restored
# with PeftModel.from_pretrained earlier in the script; confirm in
# utils/model_utils.py.
trainer._load_from_checkpoint = load_from_checkpoint
trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
# Persist the trainer state (the Trainer restricts this to the main process).
trainer.save_state()
# Exactly one process across all nodes writes the final adapter. A local_rank
# check would let every node's rank-0 process write the same path.
if trainer.is_world_process_zero():
    model.save_pretrained(os.path.join(output_dir, "checkpoint-final"))