#!/bin/bash
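
# Distil-Whisper distillation run on 2 GPUs with bf16 mixed precision:
# student distil-whisper/large-32-2 (encoder frozen), teacher openai/whisper-large-v2,
# trained on LibriSpeech (clean-100 / clean-360 / other-500) and GigaSpeech-L in streaming mode.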
accelerate launch --multi_gpu --mixed_precision=bf16 --num_processes=2 run_distillation_pt.py \
  --model_name_or_path distil-whisper/large-32-2 \
  --teacher_model_name_or_path openai/whisper-large-v2 \
  --train_dataset_config_name all+all+all+l \
  --train_dataset_samples 2.9+10.4+14.9+226.6 \
  --train_dataset_name librispeech_asr+librispeech_asr+librispeech_asr+gigaspeech-l \
  --train_split_name train.clean.100+train.clean.360+train.other.500+train \
  --eval_dataset_name librispeech_asr+librispeech_asr+gigaspeech-l \
  --eval_dataset_config_name all+all+l \
  --eval_split_name validation.clean+validation.other+validation \
  --eval_text_column_name text+text+text \
  --eval_steps 2500 \
  --save_steps 2500 \
  --warmup_steps 50 \
  --learning_rate 0.0001 \
  --lr_scheduler_type constant_with_warmup \
  --logging_steps 25 \
  --save_total_limit 1 \
  --max_steps 10000 \
  --wer_threshold 10 \
  --per_device_train_batch_size 64 \
  --gradient_accumulation_steps 2 \
  --per_device_eval_batch_size 64 \
  --dataloader_num_workers 16 \
  --cache_dir /fsx/sanchit/cache \
  --dataset_cache_dir /fsx/sanchit/cache \
  --dtype bfloat16 \
  --output_dir ./ \
  --wandb_project distil-whisper-training \
  --do_train \
  --do_eval \
  --gradient_checkpointing \
  --overwrite_output_dir \
  --predict_with_generate \
  --freeze_encoder \
  --streaming