#!/bin/bash
#
# Launches examples/glen_phase2/train_glen.py for the "the_vault" dataset,
# initialising from the phase-1 checkpoint in PHASE1_CKPT and writing to
# logs/model_glen_vault/GLEN_P2_base.
#
# All paths are relative to the current working directory, so run this script
# from the directory that contains both examples/ and logs/.
#
# Configuration is done by editing the constants below:
#   USE_DDP               "false" -> one process on GPU 0;
#                         anything else -> two processes on GPUs 0 and 1.
#   PHASE1_CKPT           passed as --model_name_or_path.
#   GPU_MEMORY_THRESHOLD  passed as --gpu_memory_threshold.
#   GPU_CHECK_INTERVAL    passed as --gpu_check_interval.

set -euo pipefail

readonly USE_DDP=false

readonly PHASE1_CKPT="logs/model_glen_vault/GLEN_P1_base"

# Forwarded verbatim to train_glen.py; their exact semantics are defined by
# that script. NOTE(review): presumably a memory-usage fraction and a step
# interval between checks -- confirm against train_glen.py.
readonly GPU_MEMORY_THRESHOLD=0.85
readonly GPU_CHECK_INTERVAL=50

readonly TRAIN_SCRIPT="examples/glen_phase2/train_glen.py"

# Arguments shared by the single-GPU and the DDP launch. Keeping them in one
# array guarantees both modes train with the same settings.
train_args=(
  --output_dir logs/model_glen_vault/GLEN_P2_base
  --model_name_or_path "${PHASE1_CKPT}"
  --load_best_model_at_end True
  --per_device_train_batch_size 4
  --per_device_eval_batch_size 2
  --gradient_accumulation_steps 32
  --dropout_rate 0.1
  --warmup_ratio 0.1
  --id_class t5_bm25_truncate_3
  --dataset_name the_vault
  --test100 1
  --tree 1
  --q_max_len 32
  --p_max_len 256
  --negative_passage_type self
  --positive_passage_no_shuffle True
  --tie_word_embeddings True
  --num_return_sequences 10
  --logging_steps 100
  --overwrite_output_dir
  --wandb_tag glen_vault_p2
  --do_eval
  --seed 42
  --gpu_memory_threshold "${GPU_MEMORY_THRESHOLD}"
  --gpu_check_interval "${GPU_CHECK_INTERVAL}"
  --fp16 True
)

if [[ "${USE_DDP}" == false ]]; then
  # Single process, restricted to GPU 0.
  CUDA_VISIBLE_DEVICES=0 \
    python "${TRAIN_SCRIPT}" "${train_args[@]}"
else
  # Two processes, one per visible GPU (0 and 1).
  # NOTE(review): torch.distributed.launch is deprecated upstream in favour of
  # torchrun; it is kept here because the two pass the local rank differently
  # and train_glen.py's handling of it is not visible from this script.
  CUDA_VISIBLE_DEVICES=0,1 \
    python -m torch.distributed.launch --nproc_per_node=2 "${TRAIN_SCRIPT}" \
    --ddp_find_unused_parameters False \
    "${train_args[@]}"
fi