How to train with long inputs (Training memory requirement)

#62
by jiang719

Hi,

StarCoder supports inputs of up to 8192 tokens, so I assume you also trained the model with such long inputs. But when I tried to fine-tune it, I found I cannot even use inputs of 2048 tokens.

Even with 4 A100 80GB GPUs, half precision, DeepSpeed ZeRO-3, parameter/optimizer offloading, and gradient checkpointing all enabled, I still get OOM with batch_size_per_gpu set to 1. Monitoring GPU usage, it seems the forward pass on 2048 tokens already takes all the GPU memory, so the OOM happens even before backward() is called.
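For reference, the setup I'm describing looks roughly like the sketch below (a minimal example using the HF Trainer, not my exact script; the model name, output directory, and hyperparameter values are placeholders):

```python
import torch
from transformers import AutoModelForCausalLM, TrainingArguments

# DeepSpeed ZeRO-3 with parameter/optimizer CPU offload, passed to the Trainer as a dict.
ds_config = {
    "zero_optimization": {
        "stage": 3,
        "offload_param": {"device": "cpu", "pin_memory": True},
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
    },
    "bf16": {"enabled": True},
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
}

# Building the arguments first lets from_pretrained shard the weights under ZeRO-3.
training_args = TrainingArguments(
    output_dir="starcoder-ft",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    gradient_checkpointing=True,  # recompute activations in backward to save memory
    bf16=True,
    deepspeed=ds_config,
)

model = AutoModelForCausalLM.from_pretrained(
    "bigcode/starcoder", torch_dtype=torch.bfloat16
)
# ... build a Trainer with training_args, model, and the tokenized dataset as usual.
```

I launch this with `deepspeed --num_gpus 4 train.py` (or `torchrun --nproc_per_node 4 train.py`), and the OOM still shows up during the forward pass.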

I wonder how you managed to train it with even longer inputs. Do you have any suggestions on training/fine-tuning with long inputs? Thank you!

We used tensor parallelism and pipeline parallelism when training with Megatron-LM to split the model across multiple GPUs (you can find details here). You can also use PEFT fine-tuning, which requires much less memory and will allow you to fit a large context (example here).
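Roughly, the PEFT route looks like the sketch below (a minimal LoRA example, not the exact script from the linked example; the rank/alpha values and the target module names for StarCoder's attention layers are assumptions you may need to adjust):

```python
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Load the frozen base model in bf16 and attach small trainable LoRA adapters;
# only the adapter weights get gradients and optimizer states.
model = AutoModelForCausalLM.from_pretrained(
    "bigcode/starcoder", torch_dtype=torch.bfloat16, device_map="auto"
)

lora_config = LoraConfig(
    r=16,                                 # adapter rank (assumed value)
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["c_attn", "c_proj"],  # assumed GPTBigCode attention projection names
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()        # a small fraction of the full model
```

From there you can train with the Trainer or your own loop as usual.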

Hello @loubnabnl, with tensor parallelism = 4 and pipeline parallelism = 4, we need 4*4 = 16 GPUs, right? Correct me if I am wrong.
Even with PEFT LoRA I am not able to fine-tune on an 8K context with 6 A100 40GB GPUs, and even with QLoRA I am not able to train the StarCoder model on 8K either.
Any suggestions? Am I missing anything? Please advise. Thank you!
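For reference, my QLoRA attempt is roughly the sketch below (not my exact script; the model name, quantization settings, LoRA hyperparameters, and target modules are placeholders/assumptions):

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Quantize the frozen base model to 4-bit NF4 so the weights fit in GPU memory,
# then train small LoRA adapters on top (the QLoRA recipe).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    "bigcode/starcoder",
    quantization_config=bnb_config,
    device_map="auto",
)
# Freezes the base weights, enables gradient checkpointing, and prepares the model
# for k-bit training.
model = prepare_model_for_kbit_training(model)

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["c_attn", "c_proj"],  # assumed GPTBigCode projection names
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# Even so, activations at 8K tokens dominate memory, so I use a micro-batch size
# of 1 with gradient accumulation, and it still runs out of memory.
```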
