High GPU Memory Usage with Multiple Image Inputs in Idefics2 Fine-Tuning

#66
by dyliu - opened

Hi,

Thanks for the amazing open-source models!

I've been fine-tuning the Idefics2 model using the official Colab tutorial you provided, and I tried fine-tuning it on multi-image inputs.

Here is the only modification I made to the tutorial code:
[screenshot of the modified tutorial code: image.png]

I've encountered an issue with GPU memory usage. When I include the same image twice in each input, peak GPU memory usage rises to around 30GB. With five images per input, it jumps to approximately 60GB. The model itself requires only 7GB (all numbers with QLoRA, float16, and a batch size of 2). This suggests that each additional image in the input uses around 3GB of memory. I've made sure that do_splitting is set to False.

Is this behavior expected, or could there be an optimization issue? Any insights or suggestions would be greatly appreciated.

Thank you and looking forward to your reply!

hi @dyliu
Yes, the vision encoder is indeed memory hungry.
Are you using flash attention and fp16/bf16? It helps a lot with memory.
Beyond that, I would also recommend enabling gradient checkpointing if you are training.
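A minimal sketch of those three settings with Hugging Face transformers, assuming the `HuggingFaceM4/idefics2-8b` checkpoint, a GPU supported by the `flash-attn` package, and `bitsandbytes` for the QLoRA quantization. The function name `load_idefics2_for_qlora` is hypothetical, and the LoRA adapter setup from the tutorial is left out.

```python
def load_idefics2_for_qlora(model_id: str = "HuggingFaceM4/idefics2-8b"):
    """Hypothetical helper: load Idefics2 with fp16, flash attention and gradient checkpointing."""
    # Imports live inside the function so that merely defining it needs no GPU or downloads.
    import torch
    from transformers import AutoProcessor, BitsAndBytesConfig, Idefics2ForConditionalGeneration

    # Keep image splitting off, as in the question (64 visual tokens per image).
    processor = AutoProcessor.from_pretrained(model_id, do_image_splitting=False)

    # 4-bit NF4 quantization of the frozen base weights (the "Q" in QLoRA).
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.float16,
    )

    # Flash attention requires fp16/bf16 and avoids materializing full attention matrices.
    model = Idefics2ForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        quantization_config=bnb_config,
        attn_implementation="flash_attention_2",  # "sdpa" is an alternative if flash-attn is unavailable
    )

    # Recompute activations during the backward pass instead of storing them all.
    model.gradient_checkpointing_enable()
    return processor, model
```

If you train through `Trainer`, passing `gradient_checkpointing=True` in `TrainingArguments` is an equivalent way to turn checkpointing on.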

Thanks for your prompt reply, @VictorSanh! How should I understand the memory usage of the vision encoder? My understanding is that each image is split into 64 patches (with do_splitting off), and each patch has its own embedding of the same size as the word embeddings. If that's right, do these 64 tokens really take ~3GB of memory?

When image splitting is off, a single image is resized so that its longest side is between 378 and 980 pixels (preserving the aspect ratio) and then split into square patches of 14 x 14 pixels. The image is then represented by a sequence of N hidden states, where N is the number of 14 x 14 patches in the image. That sequence of N hidden states goes through the vision transformer.
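The resize-then-patch step can be sketched as follows. This is an approximation built only from the description above: the real Idefics2 image processor may round or pad differently, and the names `resized_size` and `num_patches` are hypothetical.

```python
# Values from the explanation above; rounding details are assumptions.
PATCH_SIZE = 14
MIN_LONGEST_SIDE = 378
MAX_LONGEST_SIDE = 980


def resized_size(width: int, height: int) -> tuple[int, int]:
    """Scale the image so its longest side lies in [378, 980], keeping the aspect ratio."""
    longest = max(width, height)
    target = min(max(longest, MIN_LONGEST_SIDE), MAX_LONGEST_SIDE)
    scale = target / longest
    return round(width * scale), round(height * scale)


def num_patches(width: int, height: int) -> int:
    """N = number of whole 14 x 14 patches in the resized image."""
    w, h = resized_size(width, height)
    return (w // PATCH_SIZE) * (h // PATCH_SIZE)


for size in [(980, 980), (2000, 1000), (640, 480)]:
    print(size, "->", num_patches(*size), "patches")
```

A 980 x 980 image yields 70 x 70 = 4900 patches, so the sequence inside the vision encoder can be far longer than the final visual token count, and it grows with image resolution.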

Thanks for the clarification! Two questions:

  1. Why 14x14? I read in the blog that the number of tokens per image is either 64 or 320, depending on do_splitting.
  2. If 14x14 means there are 256 tokens per image, and assuming the embedding size is 1024 with each value stored as float32 (4 bytes), then 256 tokens should take up approximately 256 x 1024 x 4B = ~1MB. Why does it take about 3GB in the experiment, then?
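The estimate in the second question can be reproduced under its own assumptions (256 tokens, width 1024, float32). It counts a single stored tensor only, which is why it comes out so small; `tensor_bytes` is a hypothetical helper.

```python
def tensor_bytes(num_tokens: int, hidden_size: int, bytes_per_value: int) -> int:
    """Memory for a single [num_tokens, hidden_size] tensor."""
    return num_tokens * hidden_size * bytes_per_value


size = tensor_bytes(256, 1024, 4)  # the question's assumptions: 256 tokens, width 1024, float32
print(size, "bytes =", size / 2**20, "MiB")  # 1048576 bytes = 1.0 MiB
```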

1/ 14 is the patch size that the underlying SigLIP vision encoder uses! At the end of the vision transformer, each image's sequence is pooled into a shorter sequence of 64 visual tokens by a perceiver resampler. With image splitting, each image becomes 4 crops plus the original, hence 320 = 5 * 64.
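The token counts the language model sees can be sketched like this. The 64 comes from the reply above; treating splitting as 4 crops plus the original follows the Idefics2 release description, and `llm_visual_tokens` is a hypothetical name.

```python
# 64 visual tokens per (sub-)image after the perceiver resampler.
PERCEIVER_LATENTS = 64
# Assumption: splitting produces 4 crops, and the original image is kept as well.
CROPS_WHEN_SPLITTING = 4


def llm_visual_tokens(num_images: int, do_image_splitting: bool) -> int:
    """Visual tokens passed to the language model after the perceiver resampler."""
    sub_images = CROPS_WHEN_SPLITTING + 1 if do_image_splitting else 1
    return num_images * sub_images * PERCEIVER_LATENTS


print(llm_visual_tokens(1, do_image_splitting=False))  # 64
print(llm_visual_tokens(1, do_image_splitting=True))   # 320
print(llm_visual_tokens(5, do_image_splitting=False))  # 320
```

Because pooling happens only after the vision transformer, these counts say little about vision-encoder memory: the encoder still processes the full N-patch sequence for every image before it is reduced to 64 tokens.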

2/ 14 x 14 is the patch size, so it does not mean there will always be 256 tokens per image. N varies depending on the size of the image.
Beyond the embeddings themselves, the operations in the transformer also take memory, for instance to store activations.
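A back-of-the-envelope sketch of why activations dominate. It assumes a SigLIP so400m-style vision encoder (27 layers, 16 heads, hidden size 1152), fp16 activations, and eager (non-flash) attention that materializes a full N x N attention matrix per head. None of these settings are stated in the thread, and the estimate ignores MLP activations, gradients, and optimizer state.

```python
# Assumed vision-encoder settings (SigLIP so400m-style); not confirmed in the thread.
NUM_LAYERS = 27
NUM_HEADS = 16
HIDDEN_SIZE = 1152
BYTES_PER_VALUE = 2  # fp16 / bf16
PATCH_SIZE = 14


def hidden_states_bytes(num_patches: int) -> int:
    """One layer's output for one image: N hidden states of size HIDDEN_SIZE."""
    return num_patches * HIDDEN_SIZE * BYTES_PER_VALUE


def attention_probs_bytes(num_patches: int) -> int:
    """One layer's N x N attention probabilities (all heads) for one image, as eager attention stores them."""
    return NUM_HEADS * num_patches * num_patches * BYTES_PER_VALUE


n = (980 // PATCH_SIZE) ** 2  # 70 x 70 = 4900 patches for a 980 x 980 image
print(f"hidden states per layer:     {hidden_states_bytes(n) / 1e6:.1f} MB")
print(f"attention probs per layer:   {attention_probs_bytes(n) / 1e9:.2f} GB")
print(f"attention probs, all layers: {NUM_LAYERS * attention_probs_bytes(n) / 1e9:.1f} GB")
```

Under these assumptions a single large image needs about 11 MB of hidden states per layer but roughly 0.77 GB of attention probabilities per layer, and the cost is paid once per image per input. Flash attention avoids storing the N x N matrix, and gradient checkpointing avoids keeping every layer's activations, which is why both were suggested earlier in the thread.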

VictorSanh changed discussion status to closed
