# Fine-tune OpenAI Whisper with LoRA / AdaLoRA adapters (script exported from a Hugging Face Space).
import argparse
import functools
import os
import platform

import torch
from peft import (AdaLoraConfig, LoraConfig, PeftModel, get_peft_model,
                  prepare_model_for_kbit_training)
from transformers import (Seq2SeqTrainer, Seq2SeqTrainingArguments,
                          WhisperForConditionalGeneration, WhisperProcessor)

from utils.callback import SavePeftModelCallback
from utils.data_utils import DataCollatorSpeechSeq2SeqWithPadding
from utils.model_utils import load_from_checkpoint
from utils.reader import CustomDataset
from utils.utils import add_arguments, make_inputs_require_grad, print_arguments
# ---------------------------------------------------------------------------
# Command-line arguments.
# NOTE(review): `add_arguments` is a project helper (utils.utils); the
# `type=bool` options assume it converts strings like "False" to real
# booleans — plain argparse `type=bool` would make any non-empty string
# truthy. Confirm against utils/utils.py.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg("train_data", type=str, default="dataset/train.json", help="Path to the training data manifest (JSON)")
add_arg("test_data", type=str, default="dataset/test.json", help="Path to the evaluation data manifest (JSON)")
add_arg("base_model", type=str, default="openai/whisper-tiny", help="Base Whisper checkpoint to fine-tune")
add_arg("output_dir", type=str, default="output/", help="Directory where checkpoints are written")
add_arg("warmup_steps", type=int, default=50, help="Number of learning-rate warmup steps")
add_arg("logging_steps", type=int, default=100, help="Log metrics every N steps")
add_arg("eval_steps", type=int, default=1000, help="Run evaluation every N steps")
add_arg("save_steps", type=int, default=1000, help="Save a checkpoint every N steps")
add_arg("num_workers", type=int, default=8, help="Number of dataloader worker processes")
add_arg("learning_rate", type=float, default=1e-3, help="Peak learning rate")
add_arg("min_audio_len", type=float, default=0.5, help="Minimum audio duration in seconds")
add_arg("max_audio_len", type=float, default=30, help="Maximum audio duration in seconds")
add_arg("use_adalora", type=bool, default=True, help="Use AdaLoRA instead of plain LoRA")
add_arg("fp16", type=bool, default=True, help="Train with fp16 mixed precision")
add_arg("use_8bit", type=bool, default=False, help="Load the base model quantized to 8 bit")
add_arg("timestamps", type=bool, default=False, help="Train with timestamp tokens enabled")
add_arg("local_files_only", type=bool, default=False, help="Only use locally cached model files")
add_arg("num_train_epochs", type=int, default=3, help="Number of training epochs")
add_arg("language", type=str, default="bn", help="Language code for transcription")
add_arg("task", type=str, default="transcribe", choices=['transcribe', 'translate'], help="Model task")
add_arg("augment_config_path", type=str, default=None, help="Path to the audio augmentation config (optional)")
add_arg("resume_from_checkpoint", type=str, default=None, help="Adapter checkpoint to resume training from")
add_arg("per_device_train_batch_size", type=int, default=8, help="Training batch size per device")
add_arg("per_device_eval_batch_size", type=int, default=8, help="Evaluation batch size per device")
add_arg("gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
args = parser.parse_args()
print_arguments(args)
# Processor bundles the Whisper feature extractor and tokenizer for the
# chosen checkpoint, language and task.
processor = WhisperProcessor.from_pretrained(
    args.base_model,
    language=args.language,
    task=args.task,
    no_timestamps=not args.timestamps,
    local_files_only=args.local_files_only,
)

# Train and eval datasets share every option except the manifest path and
# the augmentation config (evaluation data is never augmented).
shared_ds_kwargs = dict(
    processor=processor,
    language=args.language,
    timestamps=args.timestamps,
    min_duration=args.min_audio_len,
    max_duration=args.max_audio_len,
)
train_dataset = CustomDataset(data_list_path=args.train_data,
                              augment_config_path=args.augment_config_path,
                              **shared_ds_kwargs)
test_dataset = CustomDataset(data_list_path=args.test_data, **shared_ds_kwargs)
print(f"len train - {len(train_dataset)} test len - {len(test_dataset)}")

# Collator that pads each batch's features and labels to a common length.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
# Device placement: under DDP (WORLD_SIZE > 1) pin the whole model to this
# process's local GPU; otherwise let HF shard it automatically.
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
if ddp:
    device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
else:
    device_map = "auto"

# Base Whisper model, optionally loaded in 8-bit to save memory.
model = WhisperForConditionalGeneration.from_pretrained(
    args.base_model,
    load_in_8bit=args.use_8bit,
    device_map=device_map,
    local_files_only=args.local_files_only,
)
# Clear generation-time token forcing/suppression so training controls them.
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []
# Prepare the (possibly quantized) model for k-bit / adapter training.
model = prepare_model_for_kbit_training(model)
# Hook the first encoder conv so its inputs require grad and LoRA gradients
# can flow back through the frozen encoder.
model.model.encoder.conv1.register_forward_hook(make_inputs_require_grad)
print('加载LoRA模块...')
if args.resume_from_checkpoint:
    # Re-attach previously trained adapters and keep them trainable.
    print("Loading adapters from checkpoint.")
    model = PeftModel.from_pretrained(model, args.resume_from_checkpoint, is_trainable=True)
else:
    print(f'adding LoRA modules...')
    # Attention projections plus both MLP layers get adapters.
    target_modules = ["k_proj", "q_proj", "v_proj", "out_proj", "fc1", "fc2"]
    print(target_modules)
    if args.use_adalora:
        # AdaLoRA: rank is pruned from init_r down to target_r during training.
        peft_config = AdaLoraConfig(init_r=12, target_r=4, beta1=0.85, beta2=0.85,
                                    tinit=200, tfinal=1000, deltaT=10,
                                    lora_alpha=32, lora_dropout=0.1,
                                    orth_reg_weight=0.5, target_modules=target_modules)
    else:
        peft_config = LoraConfig(r=32, lora_alpha=64, target_modules=target_modules,
                                 lora_dropout=0.05, bias="none")
    model = get_peft_model(model, peft_config)
# Checkpoints land under <output_dir>/<base model name>/.
output_dir = os.path.join(args.output_dir, os.path.basename(args.base_model))

# Trainer hyper-parameters. Evaluation and saving both run on a step
# schedule so load_best_model_at_end can pick the best checkpoint.
training_args = Seq2SeqTrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=args.per_device_train_batch_size,
    per_device_eval_batch_size=args.per_device_eval_batch_size,
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    learning_rate=args.learning_rate,
    warmup_steps=args.warmup_steps,
    num_train_epochs=args.num_train_epochs,
    save_strategy="steps",
    evaluation_strategy="steps",
    load_best_model_at_end=True,
    fp16=args.fp16,
    report_to=["tensorboard"],
    save_steps=args.save_steps,
    eval_steps=args.eval_steps,
    save_total_limit=5,
    optim='adamw_torch',
    # Under DDP all parameters are used every step; skip the unused-param scan.
    ddp_find_unused_parameters=False if ddp else None,
    dataloader_num_workers=args.num_workers,
    logging_steps=args.logging_steps,
    # Keep dataset columns; the custom collator needs them.
    remove_unused_columns=False,
    label_names=["labels"],
)
# Only the main process (rank 0, or -1 when not distributed) prints the
# trainable-parameter summary.
if training_args.local_rank == 0 or training_args.local_rank == -1:
    print('=' * 90)
    model.print_trainable_parameters()
    print('=' * 90)

# PyTorch >= 2.0 can JIT-compile the model for faster training.
# BUG FIX: the original condition was `== 'windows'`, which enabled
# torch.compile ONLY on Windows — exactly the platform where torch 2.x does
# not support it. Compile everywhere *except* Windows instead.
if torch.__version__ >= "2" and platform.system().lower() != 'windows':
    model = torch.compile(model)
| # | |
# Build the trainer; SavePeftModelCallback stores adapter weights at each
# checkpoint so full model weights are not duplicated on disk.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    data_collator=data_collator,
    tokenizer=processor.feature_extractor,
    callbacks=[SavePeftModelCallback],
)
# The decoder cache is not used during training.
model.config.use_cache = False
# Route checkpoint restoration through the project's PEFT-aware loader.
trainer._load_from_checkpoint = load_from_checkpoint

# Train (optionally resuming from an adapter checkpoint).
trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)

# Persist trainer state, then have the main process save the final adapters.
trainer.save_state()
if training_args.local_rank == 0 or training_args.local_rank == -1:
    model.save_pretrained(os.path.join(output_dir, "checkpoint-final"))