#!/bin/bash
# Source: GLEN-model/scripts/train_full_vault.sh
# Author: QuanTH02 — revision 15-06-v2 (commit 08894ba)
#!/bin/bash
# Full two-phase GLEN training pipeline on The Vault dataset:
# preprocessing -> Phase 1 -> Phase 2 -> docid generation -> query inference.

# -u: error on unset variables; pipefail: a pipeline fails if any stage fails.
# -e is deliberately NOT set: every critical command below is checked
# explicitly so that a per-step error message can be printed.
set -uo pipefail

echo "==========================================="
echo "Full Training GLEN on The Vault dataset"
echo "==========================================="

# GPU memory protection knobs, forwarded to every training/inference step.
readonly GPU_MEMORY_THRESHOLD=0.85   # guard trips when GPU memory use exceeds 85%
readonly GPU_CHECK_INTERVAL=50       # check memory every N training steps

echo "GPU Memory Protection enabled:"
echo "- Memory threshold: ${GPU_MEMORY_THRESHOLD} (85%)"
echo "- Check interval: ${GPU_CHECK_INTERVAL} steps"
echo ""
# Ensure data preprocessing is done: both the training documents and the
# dev queries must exist before Phase 1 can start.
echo "Checking data preprocessing..."
if [ ! -f "data/the_vault/DOC_VAULT_train.tsv" ] || [ ! -f "data/the_vault/GTQ_VAULT_dev.tsv" ]; then
  echo "Running data preprocessing..."
  # --create_test_set also produces the dev query split checked above.
  # Check the exit status directly instead of testing $? afterwards (SC2181).
  if ! python scripts/preprocess_vault_dataset.py \
      --input_dir the_vault_dataset/ \
      --output_dir data/the_vault/ \
      --create_test_set; then
    echo "Error: Data preprocessing failed!" >&2
    exit 1
  fi
else
  echo "Data already preprocessed."
fi
# Phase 1 Training: teach the t5-base model to assign document IDs.
echo ""
echo "=== Phase 1 Training (Document ID Assignment) ==="
export CUDA_VISIBLE_DEVICES="0"
# NOTE(review): --test100 1 looks like a small-subset/debug switch — confirm
# it is intended for a *full* training run.
# Exit status is checked directly on the command (SC2181).
if ! python examples/glen_phase1/train_glen.py \
    --output_dir logs/glen_vault/GLEN_P1 \
    --model_name_or_path t5-base \
    --query_type gtq_doc \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 2 \
    --dropout_rate 0.1 \
    --Rdrop 0.15 \
    --aug_query True \
    --aug_query_type corrupted_query \
    --input_dropout 1 \
    --id_class t5_bm25_truncate_3 \
    --dataset_name the_vault \
    --test100 1 \
    --tree 1 \
    --pretrain_decoder True \
    --max_input_length 128 \
    --val_check_interval 1.0 \
    --tie_word_embeddings True \
    --decoder_input doc_rep \
    --max_output_length 5 \
    --num_return_sequences 5 \
    --logging_steps 100 \
    --overwrite_output_dir \
    --wandb_tag glen_vault_p1 \
    --do_eval True \
    --num_train_epochs 3 \
    --save_steps 1000 \
    --save_strategy steps \
    --evaluation_strategy steps \
    --eval_steps 1000 \
    --seed 42 \
    --gpu_memory_threshold "$GPU_MEMORY_THRESHOLD" \
    --gpu_check_interval "$GPU_CHECK_INTERVAL" \
    --fp16 True; then
  echo "Error: Phase 1 training failed!" >&2
  exit 1
fi
echo "✅ Phase 1 training completed successfully!"
# Verify Phase 1 produced a usable checkpoint before starting Phase 2.
PHASE1_CKPT="logs/glen_vault/GLEN_P1"
if [ ! -d "$PHASE1_CKPT" ]; then
  echo "Error: Phase 1 checkpoint not found at $PHASE1_CKPT"
  exit 1
fi

# Accept either a legacy PyTorch binary or a safetensors checkpoint.
# model_files is reused later for the Phase 2 fallback check.
model_files=("pytorch_model.bin" "model.safetensors")
found_model=false
for candidate in "${model_files[@]}"; do
  if [ -f "$PHASE1_CKPT/$candidate" ]; then
    found_model=true
    echo "📁 Found Phase 1 model: $candidate"
    break
  fi
done

if [ "$found_model" = false ]; then
  echo "Error: No model files found in Phase 1 checkpoint"
  exit 1
fi
echo ""
echo "=== Phase 2 Training (Ranking-based Refinement) ==="
# Phase 2: refine the Phase 1 model with ranking supervision, initialising
# from the Phase 1 checkpoint. Exit status checked directly (SC2181).
if ! python examples/glen_phase2/train_glen.py \
    --output_dir logs/glen_vault/GLEN_P2 \
    --model_name_or_path "$PHASE1_CKPT" \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 2 \
    --gradient_accumulation_steps 4 \
    --dropout_rate 0.1 \
    --warmup_ratio 0.1 \
    --id_class t5_bm25_truncate_3 \
    --dataset_name the_vault \
    --tree 1 \
    --q_max_len 32 \
    --p_max_len 128 \
    --negative_passage_type self \
    --positive_passage_no_shuffle True \
    --tie_word_embeddings True \
    --num_return_sequences 5 \
    --logging_steps 100 \
    --overwrite_output_dir \
    --wandb_tag glen_vault_p2 \
    --do_eval True \
    --num_train_epochs 3 \
    --save_steps 1000 \
    --save_strategy steps \
    --evaluation_strategy steps \
    --eval_steps 1000 \
    --seed 42 \
    --gpu_memory_threshold "$GPU_MEMORY_THRESHOLD" \
    --gpu_check_interval "$GPU_CHECK_INTERVAL" \
    --fp16 True; then
  echo "Error: Phase 2 training failed!" >&2
  exit 1
fi
echo "✅ Phase 2 training completed successfully!"
# Validate the Phase 2 output directory and locate the trained weights.
PHASE2_CKPT="logs/glen_vault/GLEN_P2"
if [ ! -d "$PHASE2_CKPT" ]; then
  echo "Error: Phase 2 checkpoint not found at $PHASE2_CKPT" >&2
  exit 1
fi

# Prefer the newest checkpoint-* subdirectory (trainer-style layout);
# sort -V ranks checkpoint-NNN numerically so tail picks the latest step.
checkpoint_dir=$(find "$PHASE2_CKPT" -maxdepth 1 -type d -name "checkpoint-*" | sort -V | tail -n 1)
if [ -n "$checkpoint_dir" ]; then
  # Quote the basename argument (SC2086): the path could contain spaces.
  echo "📁 Found Phase 2 checkpoint: $(basename "$checkpoint_dir")"
  if [ ! -f "$checkpoint_dir/model.safetensors" ] && [ ! -f "$checkpoint_dir/pytorch_model.bin" ]; then
    echo "Error: No model files in checkpoint directory" >&2
    exit 1
  fi
else
  # No checkpoint subdirectory — fall back to model files in the output root.
  found_model=false
  for file in "${model_files[@]}"; do
    if [ -f "$PHASE2_CKPT/$file" ]; then
      found_model=true
      echo "📁 Found Phase 2 model: $file"
      break
    fi
  done
  if [ "$found_model" = false ]; then
    echo "Error: No model files found in Phase 2 checkpoint" >&2
    exit 1
  fi
fi
echo ""
echo "=== Document ID Generation ==="
# Generate the final document-ID table from the trained Phase 2 model.
# Exit status checked directly (SC2181); checkpoint path quoted.
if ! python examples/glen_phase2/makeid_glen.py \
    --model_name_or_path "$PHASE2_CKPT" \
    --infer_dir "$PHASE2_CKPT" \
    --dataset_name the_vault \
    --docid_file_name GLEN_P2_docids \
    --per_device_eval_batch_size 4 \
    --max_input_length 128 \
    --num_return_sequences 10; then
  echo "Error: Document ID generation failed!" >&2
  exit 1
fi

# The docid table must exist before inference; line_count and docid_file
# are also reported in the final summary.
docid_file="logs/glen_vault/GLEN_P2_docids.tsv"
if [ ! -f "$docid_file" ]; then
  echo "Error: Document ID file not created: $docid_file" >&2
  exit 1
fi
line_count=$(wc -l < "$docid_file")
echo "✅ Document ID generation completed! Generated $line_count document IDs"
echo ""
echo "=== Query Inference ==="
# Dev queries are produced by preprocessing when --create_test_set is passed.
if [ ! -f "data/the_vault/GTQ_VAULT_dev.tsv" ]; then
  echo "Error: Test queries file not found. Please run preprocessing with --create_test_set flag" >&2
  exit 1
fi
# Evaluate retrieval with the generated docid table.
# Exit status checked directly (SC2181); checkpoint path quoted.
if ! python examples/glen_phase2/evaluate_glen.py \
    --model_name_or_path "$PHASE2_CKPT" \
    --infer_dir "$PHASE2_CKPT" \
    --dataset_name the_vault \
    --docid_file_name GLEN_P2_docids \
    --per_device_eval_batch_size 4 \
    --q_max_len 32 \
    --num_return_sequences 5 \
    --logs_dir logs/glen_vault \
    --test100 1; then
  echo "Error: Query inference failed!" >&2
  exit 1
fi
echo "✅ Query inference completed successfully!"
# Final report. $line_count and $docid_file are set by the docid-generation
# step above; the GPU_* constants come from the top of the script.
# A single expanding here-doc replaces the long echo chain.
cat <<EOF

===========================================
🎉 FULL TRAINING COMPLETED SUCCESSFULLY! 🎉
===========================================

📊 Summary:
 ✅ Phase 1 Training (Document ID Assignment)
 ✅ Phase 2 Training (Ranking-based Refinement)
 ✅ Document ID Generation ($line_count IDs)
 ✅ Query Inference & Evaluation

📁 Results saved in: logs/glen_vault/
📁 Document IDs: $docid_file

🛡️ Memory Protection Summary:
 - GPU memory threshold: ${GPU_MEMORY_THRESHOLD} (85%)
 - Check interval: ${GPU_CHECK_INTERVAL} steps
 - FP16 training enabled
 - Optimized batch sizes used

🚀 Training completed! The model is ready for production use.
EOF