#!/usr/bin/env pwsh
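# scripts/test_small_training.ps1
# End-to-end smoke test of the GLEN pipeline on a small sample of The Vault dataset:
# Phase 1 training, Phase 2 training, document ID generation, and query inference.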
Write-Host "==========================================="
Write-Host "Testing GLEN with small Vault dataset"
Write-Host "==========================================="
# Set memory monitoring parameters
$GPU_MEMORY_THRESHOLD = 0.85
$GPU_CHECK_INTERVAL = 50
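# Both values are forwarded to the Python trainers via --gpu_memory_threshold / --gpu_check_interval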
Write-Host "GPU Memory Protection enabled:"
Write-Host "- Memory threshold: ${GPU_MEMORY_THRESHOLD} (85%)"
Write-Host "- Check interval: ${GPU_CHECK_INTERVAL} steps"
Write-Host ""
# Ensure data preprocessing is done
Write-Host "Checking data preprocessing..."
if (-not (Test-Path "data/the_vault/DOC_VAULT_train.tsv")) {
    Write-Host "Running data preprocessing..."
    python scripts/preprocess_vault_dataset.py --input_dir the_vault_dataset/ --output_dir data/the_vault/ --sample_size 1000
    if ($LASTEXITCODE -ne 0) {
        Write-Error "Data preprocessing failed!"
        exit 1
    }
} else {
    Write-Host "Data already preprocessed."
}
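# Optional GPU sanity check before training (a minimal sketch; assumes nvidia-smi is on PATH).
# It only warns and never aborts, so CPU-only debugging still works.
if (Get-Command nvidia-smi -ErrorAction SilentlyContinue) {
    Write-Host "GPU status:"
    nvidia-smi --query-gpu=name,memory.total,memory.used --format=csv,noheader
} else {
    Write-Warning "nvidia-smi not found; GPU memory usage cannot be inspected from this shell."
}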
# Test Phase 1 Training
Write-Host ""
Write-Host "=== Phase 1 Training (Document ID Assignment) ==="
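# Restrict training to the first visible GPU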
$env:CUDA_VISIBLE_DEVICES = "0"
try {
    python examples/glen_phase1/train_glen.py `
        --output_dir logs/test_glen_vault/GLEN_P1_test `
        --model_name_or_path t5-base `
        --query_type gtq_doc `
        --per_device_train_batch_size 8 `
        --per_device_eval_batch_size 4 `
        --gradient_accumulation_steps 2 `
        --dropout_rate 0.1 `
        --Rdrop 0.15 `
        --aug_query True `
        --aug_query_type corrupted_query `
        --input_dropout 1 `
        --id_class t5_bm25_truncate_3 `
        --dataset_name the_vault `
        --test100 1 `
        --tree 1 `
        --pretrain_decoder True `
        --max_input_length 128 `
        --val_check_interval 1.0 `
        --tie_word_embeddings True `
        --decoder_input doc_rep `
        --max_output_length 5 `
        --num_return_sequences 5 `
        --logging_steps 10 `
        --overwrite_output_dir `
        --wandb_tag test_glen_vault_p1 `
        --do_eval False `
        --num_train_epochs 1 `
        --save_steps 50 `
        --save_strategy steps `
        --evaluation_strategy no `
        --seed 42 `
        --gpu_memory_threshold $GPU_MEMORY_THRESHOLD `
        --gpu_check_interval $GPU_CHECK_INTERVAL `
        --fp16 True
    if ($LASTEXITCODE -ne 0) {
        throw "Phase 1 training failed!"
    }
} catch {
    Write-Error "Phase 1 training failed: $_"
    exit 1
}
Write-Host "✅ Phase 1 training completed successfully!"
# Check if Phase 1 checkpoint exists
$PHASE1_CKPT = "logs/test_glen_vault/GLEN_P1_test"
if (-not (Test-Path $PHASE1_CKPT)) {
    Write-Error "❌ Phase 1 checkpoint not found at $PHASE1_CKPT"
    exit 1
}
# Check for model files
$model_files = @("pytorch_model.bin", "model.safetensors")
$found_model = $false
foreach ($file in $model_files) {
    if (Test-Path "$PHASE1_CKPT/$file") {
        $found_model = $true
        Write-Host "📄 Found Phase 1 model: $file"
        break
    }
}
if (-not $found_model) {
    Write-Error "❌ No model files found in Phase 1 checkpoint"
    exit 1
}
Write-Host ""
Write-Host "=== Phase 2 Training (Ranking-based Refinement) ==="
# Test Phase 2 Training
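# Phase 2 fine-tunes from the Phase 1 checkpoint (passed below as --model_name_or_path)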
try {
    python examples/glen_phase2/train_glen.py `
        --output_dir logs/test_glen_vault/GLEN_P2_test `
        --model_name_or_path $PHASE1_CKPT `
        --per_device_train_batch_size 4 `
        --per_device_eval_batch_size 2 `
        --gradient_accumulation_steps 4 `
        --dropout_rate 0.1 `
        --warmup_ratio 0.1 `
        --id_class t5_bm25_truncate_3 `
        --dataset_name the_vault `
        --test100 1 `
        --tree 1 `
        --q_max_len 32 `
        --p_max_len 128 `
        --negative_passage_type self `
        --positive_passage_no_shuffle True `
        --tie_word_embeddings True `
        --num_return_sequences 5 `
        --logging_steps 10 `
        --overwrite_output_dir `
        --wandb_tag test_glen_vault_p2 `
        --do_eval False `
        --num_train_epochs 1 `
        --save_steps 50 `
        --save_strategy steps `
        --evaluation_strategy no `
        --seed 42 `
        --gpu_memory_threshold $GPU_MEMORY_THRESHOLD `
        --gpu_check_interval $GPU_CHECK_INTERVAL `
        --fp16 True
    if ($LASTEXITCODE -ne 0) {
        throw "Phase 2 training failed!"
    }
} catch {
    Write-Error "Phase 2 training failed: $_"
    exit 1
}
Write-Host "✅ Phase 2 training completed successfully!"
# Validate Phase 2 checkpoint
$PHASE2_CKPT = "logs/test_glen_vault/GLEN_P2_test"
if (-not (Test-Path $PHASE2_CKPT)) {
    Write-Error "❌ Phase 2 checkpoint not found at $PHASE2_CKPT"
    exit 1
}
# Prefer the latest checkpoint-* subdirectory (sorted by step number); otherwise fall back to model files in the output root
$latest_checkpoint = Get-ChildItem -Path $PHASE2_CKPT -Directory -Filter "checkpoint-*" -Name |
    Sort-Object { [int]($_.Split('-')[1]) } |
    Select-Object -Last 1
if ($latest_checkpoint) {
    Write-Host "📁 Found Phase 2 checkpoint: $latest_checkpoint"
    $checkpoint_path = "$PHASE2_CKPT/$latest_checkpoint"
    if (-not (Test-Path "$checkpoint_path/model.safetensors") -and -not (Test-Path "$checkpoint_path/pytorch_model.bin")) {
        Write-Error "❌ No model files in checkpoint directory"
        exit 1
    }
} else {
    # Check for model files in the output root
    $found_model = $false
    foreach ($file in $model_files) {
        if (Test-Path "$PHASE2_CKPT/$file") {
            $found_model = $true
            Write-Host "📄 Found Phase 2 model: $file"
            break
        }
    }
    if (-not $found_model) {
        Write-Error "❌ No model files found in Phase 2 checkpoint"
        exit 1
    }
}
Write-Host ""
Write-Host "=== Document ID Generation ==="
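# Assign document IDs with the Phase 2 checkpoint; output goes to GLEN_P2_test_docids.tsv under logs/test_glen_vault/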
try {
    python examples/glen_phase2/makeid_glen.py `
        --model_name_or_path $PHASE2_CKPT `
        --infer_dir $PHASE2_CKPT `
        --dataset_name the_vault `
        --docid_file_name GLEN_P2_test_docids `
        --per_device_eval_batch_size 4 `
        --max_input_length 128 `
        --num_return_sequences 10
    if ($LASTEXITCODE -ne 0) {
        throw "Document ID generation failed!"
    }
} catch {
    Write-Error "Document ID generation failed: $_"
    exit 1
}
# Validate docid file was created
$docid_file = "logs/test_glen_vault/GLEN_P2_test_docids.tsv"
if (-not (Test-Path $docid_file)) {
    Write-Error "❌ Document ID file not created: $docid_file"
    exit 1
}
$line_count = (Get-Content $docid_file).Count
Write-Host "✅ Document ID generation completed! Generated $line_count document IDs"
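# Optional peek at the first few generated IDs (assumes one docid entry per line, as counted above)
Get-Content $docid_file -TotalCount 3 | ForEach-Object { Write-Host "   $_" }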
Write-Host ""
Write-Host "=== Query Inference ==="
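# Score test queries against the generated docid file; results are written under logs/test_glen_vault/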
try {
    python examples/glen_phase2/evaluate_glen.py `
        --model_name_or_path $PHASE2_CKPT `
        --infer_dir $PHASE2_CKPT `
        --dataset_name the_vault `
        --docid_file_name GLEN_P2_test_docids `
        --per_device_eval_batch_size 4 `
        --q_max_len 32 `
        --num_return_sequences 5 `
        --logs_dir logs/test_glen_vault
    if ($LASTEXITCODE -ne 0) {
        throw "Query inference failed!"
    }
} catch {
    Write-Error "Query inference failed: $_"
    exit 1
}
Write-Host "✅ Query inference completed successfully!"
Write-Host ""
Write-Host "==========================================="
Write-Host "🎉 ALL TESTS COMPLETED SUCCESSFULLY! 🎉"
Write-Host "==========================================="
Write-Host ""
Write-Host "📊 Summary:"
Write-Host " ✅ Phase 1 Training (Document ID Assignment)"
Write-Host " ✅ Phase 2 Training (Ranking-based Refinement)"
Write-Host " ✅ Document ID Generation ($line_count IDs)"
Write-Host " ✅ Query Inference & Evaluation"
Write-Host ""
Write-Host "📁 Results saved in: logs/test_glen_vault/"
Write-Host "📄 Document IDs: $docid_file"
Write-Host ""
Write-Host "🛡️ Memory Protection Summary:"
Write-Host " - GPU memory threshold: ${GPU_MEMORY_THRESHOLD} (85%)"
Write-Host " - Check interval: ${GPU_CHECK_INTERVAL} steps"
Write-Host " - FP16 training enabled"
Write-Host " - Optimized batch sizes used"
Write-Host ""
Write-Host "🚀 The system is ready for full training on The Vault dataset!"
Write-Host " Use scripts/train_full_vault.ps1 for production training."