# GLEN-model/scripts/train_glen_p2_vault.sh
# QuanTH02's picture
# Commit 15-06-v1
# 6534252
#!/bin/bash
#
# Launch GLEN phase-2 training on the_vault dataset, initialised from the
# phase-1 checkpoint.
#
# All paths below are relative, so run this from the directory that contains
# examples/ and logs/.
#
# Optional environment overrides (defaults in parentheses):
#   USE_DDP               (false)  "true" launches 2 processes on GPUs 0,1 via
#                                  torch.distributed.launch; "false" runs a
#                                  single process on GPU 0.
#   PHASE1_CKPT           (logs/model_glen_vault/GLEN_P1_base)
#                                  Phase 1 checkpoint passed as
#                                  --model_name_or_path.
#   GPU_MEMORY_THRESHOLD  (0.85)   Passed as --gpu_memory_threshold
#                                  (85% of GPU memory).
#   GPU_CHECK_INTERVAL    (50)     Passed as --gpu_check_interval
#                                  (check every 50 steps).
#
# Exit status: that of the training command, or 2 if USE_DDP is not a
# recognised boolean.

# Fail on typos in variable names. `set -e` is intentionally not used: the
# launch is the last command, so its status already becomes the script's.
set -u

USE_DDP="${USE_DDP:-false}"

# Phase 1 checkpoint path
PHASE1_CKPT="${PHASE1_CKPT:-logs/model_glen_vault/GLEN_P1_base}"

# GPU Memory monitoring settings
GPU_MEMORY_THRESHOLD="${GPU_MEMORY_THRESHOLD:-0.85}" # 85% of GPU memory
GPU_CHECK_INTERVAL="${GPU_CHECK_INTERVAL:-50}"       # Check every 50 steps

TRAIN_SCRIPT="examples/glen_phase2/train_glen.py"
OUTPUT_DIR="logs/model_glen_vault/GLEN_P2_base"

# Flags shared by the single-process and the distributed launch. Keeping them
# in one array guarantees both modes train with identical hyperparameters.
train_args=(
  --output_dir "${OUTPUT_DIR}"
  --model_name_or_path "${PHASE1_CKPT}"
  --load_best_model_at_end True
  --per_device_train_batch_size 4
  --per_device_eval_batch_size 2
  --gradient_accumulation_steps 32
  --dropout_rate 0.1
  --warmup_ratio 0.1
  --id_class t5_bm25_truncate_3
  --dataset_name the_vault
  --test100 1
  --tree 1
  --q_max_len 32
  --p_max_len 256
  --negative_passage_type self
  --positive_passage_no_shuffle True
  --tie_word_embeddings True
  --num_return_sequences 10
  --logging_steps 100
  --overwrite_output_dir
  --wandb_tag glen_vault_p2
  --do_eval
  --seed 42
  --gpu_memory_threshold "${GPU_MEMORY_THRESHOLD}"
  --gpu_check_interval "${GPU_CHECK_INTERVAL}"
  --fp16 True
)

case "${USE_DDP}" in
  false | False | FALSE | 0)
    # Without distributed training
    CUDA_VISIBLE_DEVICES=0 \
      python "${TRAIN_SCRIPT}" "${train_args[@]}"
    ;;
  true | True | TRUE | 1)
    # With distributed training
    # NOTE(review): torch.distributed.launch is deprecated upstream in favour
    # of torchrun; switching changes how the local rank is handed to the
    # script, so confirm train_glen.py supports it before migrating.
    CUDA_VISIBLE_DEVICES=0,1 \
      python -m torch.distributed.launch --nproc_per_node=2 "${TRAIN_SCRIPT}" \
      --ddp_find_unused_parameters False \
      "${train_args[@]}"
    ;;
  *)
    # Previously any value other than "false" silently selected DDP.
    printf 'error: USE_DDP must be "true" or "false", got: %s\n' "${USE_DDP}" >&2
    exit 2
    ;;
esac