Fine-tuning bge-reranker-v2-gemma results in torch.cuda.OutOfMemoryError even with 4 GPUs

Discussion #2, opened by jackkwok

I am fine-tuning bge-reranker-v2-gemma on my own custom training dataset, using 4x NVIDIA A10G GPUs with 24 GB of memory each, which I expected to be plenty. However, I still get a CUDA OOM error shortly after training starts. Any ideas?

My command:

torchrun --nproc_per_node 4 \
-m FlagEmbedding.llm_reranker.finetune_for_instruction.run \
--output_dir model_artifacts \
--token <secret redacted> \
--model_name_or_path google/gemma-2b \
--train_data ./jsonl/train.jsonl \
--learning_rate 2e-4 \
--num_train_epochs 1 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 16 \
--dataloader_drop_last True \
--query_max_len 512 \
--passage_max_len 512 \
--train_group_size 16 \
--logging_steps 1 \
--save_steps 2000 \
--save_total_limit 50 \
--ddp_find_unused_parameters False \
--gradient_checkpointing \
--deepspeed stage1.json \
--warmup_ratio 0.1 \
--bf16 \
--use_lora True \
--lora_rank 32 \
--lora_alpha 64 \
--use_flash_attn True \
--target_modules q_proj k_proj v_proj o_proj
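A note on why 4 x 24 GB does not behave like 96 GB: with DeepSpeed ZeRO stage 1 only the optimizer states are partitioned across GPUs, so each GPU still runs its own forward and backward pass and must hold all activations for its own micro-batch. The rough Python sketch below (the function step_workload is hypothetical) shows what that micro-batch looks like under the flags above. It assumes, without confirmation from this thread, that each query is scored against train_group_size passages and that each query-passage pair becomes one sequence of at most query_max_len + passage_max_len tokens, ignoring any prompt tokens.

```python
# Rough per-step workload implied by the torchrun flags above.
# Assumptions (not confirmed by the post): one query is paired with
# `train_group_size` passages, and each query-passage pair is one sequence of
# at most query_max_len + passage_max_len tokens (prompt tokens ignored).


def step_workload(num_gpus: int,
                  per_device_train_batch_size: int,
                  gradient_accumulation_steps: int,
                  train_group_size: int,
                  query_max_len: int,
                  passage_max_len: int) -> dict:
    # Sequences one GPU pushes through forward/backward in a single micro-step.
    seqs_per_micro_step = per_device_train_batch_size * train_group_size
    # Upper bound on tokens per sequence (batches are padded to the longest one).
    max_tokens_per_seq = query_max_len + passage_max_len
    return {
        "seqs_per_micro_step": seqs_per_micro_step,
        "max_tokens_per_micro_step": seqs_per_micro_step * max_tokens_per_seq,
        # Queries contributing to one optimizer update, summed over all GPUs.
        "queries_per_optimizer_step": (num_gpus * per_device_train_batch_size
                                       * gradient_accumulation_steps),
    }


w = step_workload(num_gpus=4, per_device_train_batch_size=1,
                  gradient_accumulation_steps=16, train_group_size=16,
                  query_max_len=512, passage_max_len=512)
print(w)
```

Under these assumptions each GPU processes 16 sequences of up to 1,024 tokens per micro-step. gradient_accumulation_steps and the number of GPUs change the effective batch size (64 queries per optimizer step here) but not the per-GPU peak, which is driven by train_group_size and the two length limits.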

Here is the training log up to the failure, about 40 steps in:

{'loss': 2.9226, 'grad_norm': 5.388334274291992, 'learning_rate': 0.0, 'epoch': 0.0}                                                                    
{'loss': 2.8772, 'grad_norm': 4.197434902191162, 'learning_rate': 3.562071871080222e-05, 'epoch': 0.0}                                                  
{'loss': 2.8868, 'grad_norm': 4.446421146392822, 'learning_rate': 5.645750340535797e-05, 'epoch': 0.01}                                                 
{'loss': 2.6807, 'grad_norm': 3.2873618602752686, 'learning_rate': 7.124143742160444e-05, 'epoch': 0.01}                                                
{'loss': 2.6077, 'grad_norm': 3.4100160598754883, 'learning_rate': 8.270874753469163e-05, 'epoch': 0.01}                                                
{'loss': 2.5243, 'grad_norm': 3.642030715942383, 'learning_rate': 9.207822211616019e-05, 'epoch': 0.01}                                                 
{'loss': 2.6225, 'grad_norm': 3.3111226558685303, 'learning_rate': 0.0001, 'epoch': 0.01}                                                               
{'loss': 2.7275, 'grad_norm': 4.486476898193359, 'learning_rate': 0.00010686215613240667, 'epoch': 0.02}                                                
{'loss': 2.7934, 'grad_norm': 3.5005621910095215, 'learning_rate': 0.00011291500681071594, 'epoch': 0.02}                                               
{'loss': 2.6622, 'grad_norm': 3.005181312561035, 'learning_rate': 0.00011832946624549386, 'epoch': 0.02}                                                
{'loss': 2.6258, 'grad_norm': 3.1711008548736572, 'learning_rate': 0.0001232274405867344, 'epoch': 0.02}                                                
{'loss': 2.489, 'grad_norm': 3.336585283279419, 'learning_rate': 0.0001276989408269624, 'epoch': 0.02}                                                  
{'loss': 2.4983, 'grad_norm': 3.2431273460388184, 'learning_rate': 0.0001318123223061841, 'epoch': 0.03}                                                
{'loss': 2.4421, 'grad_norm': 3.682769298553467, 'learning_rate': 0.00013562071871080222, 'epoch': 0.03}                                                
{'loss': 2.6211, 'grad_norm': 3.925990343093872, 'learning_rate': 0.0001391662509400496, 'epoch': 0.03}                                                 
{'loss': 2.8245, 'grad_norm': 4.797396659851074, 'learning_rate': 0.00014248287484320887, 'epoch': 0.03}                                                
{'loss': 2.5547, 'grad_norm': 4.069711685180664, 'learning_rate': 0.0001455983641090348, 'epoch': 0.03}                                                 
{'loss': 2.8645, 'grad_norm': 5.024136543273926, 'learning_rate': 0.00014853572552151815, 'epoch': 0.04}                                                
{'loss': 2.3737, 'grad_norm': 3.9875905513763428, 'learning_rate': 0.00015131423106025147, 'epoch': 0.04}                                               
{'loss': 2.6015, 'grad_norm': 3.5503010749816895, 'learning_rate': 0.00015395018495629608, 'epoch': 0.04}                                               
{'loss': 2.4038, 'grad_norm': 4.3279194831848145, 'learning_rate': 0.000156457503405358, 'epoch': 0.04}                                                 
{'loss': 2.374, 'grad_norm': 3.7719438076019287, 'learning_rate': 0.00015884815929753662, 'epoch': 0.04}                                                
{'loss': 2.2142, 'grad_norm': 3.907940626144409, 'learning_rate': 0.00016113252800759313, 'epoch': 0.05}                                                
{'loss': 2.526, 'grad_norm': 3.962578296661377, 'learning_rate': 0.00016331965953776464, 'epoch': 0.05}                                                 
{'loss': 2.4629, 'grad_norm': 4.234306812286377, 'learning_rate': 0.00016541749506938325, 'epoch': 0.05}                                                
{'loss': 2.0923, 'grad_norm': 4.046939373016357, 'learning_rate': 0.00016743304101698634, 'epoch': 0.05}                                                
{'loss': 2.4197, 'grad_norm': 5.140893459320068, 'learning_rate': 0.0001693725102160739, 'epoch': 0.06}                                                 
{'loss': 2.266, 'grad_norm': 4.84731912612915, 'learning_rate': 0.00017124143742160445, 'epoch': 0.06}                                                                              
{'loss': 2.3097, 'grad_norm': 4.363345623016357, 'learning_rate': 0.00017304477452986233, 'epoch': 0.06}                                                                            
{'loss': 2.2013, 'grad_norm': 6.07499885559082, 'learning_rate': 0.00017478696965085182, 'epoch': 0.06}                                                                             
{'loss': 2.3062, 'grad_norm': 4.905301570892334, 'learning_rate': 0.0001764720332103851, 'epoch': 0.06}                                                                             
{'loss': 2.1395, 'grad_norm': 11.986255645751953, 'learning_rate': 0.00017810359355401113, 'epoch': 0.07}                                                                           
{'loss': 2.3626, 'grad_norm': 7.714587211608887, 'learning_rate': 0.00017968494399209236, 'epoch': 0.07}                                                                            
{'loss': 2.3273, 'grad_norm': 10.299479484558105, 'learning_rate': 0.00018121908281983702, 'epoch': 0.07}                                                                           
{'loss': 2.1057, 'grad_norm': 6.595864295959473, 'learning_rate': 0.00018270874753469163, 'epoch': 0.07}                                                                            
{'loss': 2.3477, 'grad_norm': 9.525544166564941, 'learning_rate': 0.00018415644423232038, 'epoch': 0.07}                                                                            
{'loss': 2.027, 'grad_norm': 13.262056350708008, 'learning_rate': 0.00018556447297411074, 'epoch': 0.08}                                                                            
{'loss': 2.087, 'grad_norm': 8.810412406921387, 'learning_rate': 0.0001869349497710537, 'epoch': 0.08}                                                                              
{'loss': 2.1892, 'grad_norm': 6.087996006011963, 'learning_rate': 0.00018826982571154205, 'epoch': 0.08}                                                                            
{'loss': 1.5995, 'grad_norm': 7.068698406219482, 'learning_rate': 0.00018957090366709828, 'epoch': 0.08}                                                                            
  8%|███████████▏                                    | 40/490 [33:43<6:10:42, 49.43s/it]
Traceback (most recent call last):
  File "/home/ec2-user/SageMaker/custom-miniconda/miniconda/envs/faiss/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/ec2-user/SageMaker/custom-miniconda/miniconda/envs/faiss/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/ec2-user/SageMaker/custom-miniconda/miniconda/envs/faiss/lib/python3.8/site-packages/FlagEmbedding/llm_reranker/finetune_for_instruction/run.py", line 131, in <module>
    main()
  File "/home/ec2-user/SageMaker/custom-miniconda/miniconda/envs/faiss/lib/python3.8/site-packages/FlagEmbedding/llm_reranker/finetune_for_instruction/run.py", line 118, in main
    trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
  File "/home/ec2-user/SageMaker/custom-miniconda/miniconda/envs/faiss/lib/python3.8/site-packages/transformers/trainer.py", line 1885, in train
    return inner_training_loop(
  File "/home/ec2-user/SageMaker/custom-miniconda/miniconda/envs/faiss/lib/python3.8/site-packages/transformers/trainer.py", line 2216, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/home/ec2-user/SageMaker/custom-miniconda/miniconda/envs/faiss/lib/python3.8/site-packages/transformers/trainer.py", line 3250, in training_step
    self.accelerator.backward(loss)
  File "/home/ec2-user/SageMaker/custom-miniconda/miniconda/envs/faiss/lib/python3.8/site-packages/accelerate/accelerator.py", line 2117, in backward
    self.deepspeed_engine_wrapped.backward(loss, **kwargs)
  File "/home/ec2-user/SageMaker/custom-miniconda/miniconda/envs/faiss/lib/python3.8/site-packages/accelerate/utils/deepspeed.py", line 166, in backward
    self.engine.backward(loss, **kwargs)
  File "/home/ec2-user/SageMaker/custom-miniconda/miniconda/envs/faiss/lib/python3.8/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/home/ec2-user/SageMaker/custom-miniconda/miniconda/envs/faiss/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 1976, in backward
    self.optimizer.backward(loss, retain_graph=retain_graph)
  File "/home/ec2-user/SageMaker/custom-miniconda/miniconda/envs/faiss/lib/python3.8/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 2056, in backward
    self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
  File "/home/ec2-user/SageMaker/custom-miniconda/miniconda/envs/faiss/lib/python3.8/site-packages/deepspeed/runtime/fp16/loss_scaler.py", line 63, in backward
    scaled_loss.backward(retain_graph=retain_graph)
  File "/home/ec2-user/SageMaker/custom-miniconda/miniconda/envs/faiss/lib/python3.8/site-packages/torch/_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "/home/ec2-user/SageMaker/custom-miniconda/miniconda/envs/faiss/lib/python3.8/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 7.32 GiB. GPU 0 has a total capacty of 22.19 GiB of which 7.13 GiB is free. Including non-PyTorch memory, this process has 15.06 GiB memory in use. Of the allocated memory 14.13 GiB is allocated by PyTorch, and 535.78 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
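The figures in this error message say more than "out of memory", so it can help to pull them out and compare them. The helper below is a hypothetical sketch (summarize_oom is not part of PyTorch); its regexes assume the wording printed by the PyTorch version in this traceback and may need adjusting for other versions.

```python
import re

# Relevant part of the OOM message from the traceback above (verbatim,
# including PyTorch's own "capacty" typo).
OOM_MSG = (
    "CUDA out of memory. Tried to allocate 7.32 GiB. GPU 0 has a total capacty "
    "of 22.19 GiB of which 7.13 GiB is free. Including non-PyTorch memory, this "
    "process has 15.06 GiB memory in use. Of the allocated memory 14.13 GiB is "
    "allocated by PyTorch, and 535.78 MiB is reserved by PyTorch but unallocated."
)

_NUM = r"([\d.]+) (KiB|MiB|GiB)"
_TO_GIB = {"KiB": 1 / 1024**2, "MiB": 1 / 1024, "GiB": 1.0}

# Patterns written against the message format shown above only.
_PATTERNS = {
    "requested": r"Tried to allocate " + _NUM,
    "total": r"total capac\w* of " + _NUM,  # tolerates "capacty" and "capacity"
    "free": r"of which " + _NUM + r" is free",
    "in_use": r"this process has " + _NUM + r" memory in use",
    "reserved_unallocated": _NUM + r" is reserved by PyTorch but unallocated",
}


def summarize_oom(msg: str) -> dict:
    """Return the OOM figures in GiB, plus how far the request exceeded free memory."""
    out = {}
    for key, pattern in _PATTERNS.items():
        m = re.search(pattern, msg)
        if m is None:
            raise ValueError(f"could not find {key!r} in message")
        out[key] = float(m.group(1)) * _TO_GIB[m.group(2)]
    out["shortfall"] = out["requested"] - out["free"]
    return out


for name, gib in summarize_oom(OOM_MSG).items():
    print(f"{name:>20}: {gib:6.2f} GiB")
```

Read this way, the failing allocation is a single 7.32 GiB tensor in the backward pass, only about 0.19 GiB more than the free memory, while just 535.78 MiB sits reserved but unallocated in PyTorch's cache. Allocator tuning such as max_split_size_mb could at best recover that few hundred MiB, which leaves little headroom for the next long batch; shrinking whatever produces the large tensor is the more durable direction.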

My dependencies:

sentence-transformers~=2.7.0
FlagEmbedding~=1.2.10
peft~=0.11.1
deepspeed~=0.14.2
flash-attn~=2.5.9.post1

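The issue was fixed by decreasing both of these parameters, which shortens every tokenized query-passage sequence and so reduces the memory each training step needs: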

--query_max_len
--passage_max_len
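One plausible reason the sequence lengths matter so much for this model: if the reranker runs a standard causal-LM forward pass, it materializes logits over the whole vocabulary at every token position, a tensor of shape (sequences, tokens, vocabulary). The sketch below estimates that tensor's size. The Gemma vocabulary of 256,000 tokens, bf16 storage (2 bytes per element), and the function name logits_gib are assumptions, and the reduced length pairs are arbitrary examples, since the thread does not say which values fixed the problem.

```python
# Back-of-envelope size of one full-vocabulary logits tensor
# (sequences x tokens x vocab). Assumptions, not from the post: a standard
# causal-LM forward that keeps logits for every position, a 256,000-entry
# Gemma vocabulary, and bf16 logits (2 bytes/element). Weights, activations,
# gradients and optimizer state are NOT included.

GEMMA_VOCAB = 256_000
BYTES_BF16 = 2


def logits_gib(seqs: int, seq_len: int, vocab: int = GEMMA_VOCAB,
               bytes_per_elem: int = BYTES_BF16) -> float:
    """Size in GiB of a (seqs, seq_len, vocab) tensor."""
    return seqs * seq_len * vocab * bytes_per_elem / 1024**3


# 16 sequences per micro-step = per_device_train_batch_size 1 x train_group_size 16.
for q_len, p_len in [(512, 512), (256, 256), (128, 256)]:
    size = logits_gib(seqs=16, seq_len=q_len + p_len)
    print(f"query_max_len={q_len:>3} passage_max_len={p_len:>3} -> {size:.2f} GiB")
```

At 512 + 512 tokens the estimate is about 7.8 GiB, the same order of magnitude as the 7.32 GiB allocation in the traceback, and it scales linearly with query_max_len + passage_max_len (and with train_group_size). If the model upcasts logits to fp32, pass bytes_per_elem=4 and every figure doubles.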
