m-ric posted an update Mar 13
Interesting paper: GaLore: train 7B models on consumer-grade GPUs 💪
It's now possible to fully pre-train a 7B model on a consumer-grade GPU with 24 GB of memory, without any performance loss!

Memory usage when training models has always been an acute issue: fully pre-training a 7B model, for instance, used to eat ~50 GB of GPU memory!
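
Where does that number come from? Here is a rough back-of-the-envelope calculation (a sketch assuming bf16 weights and a standard Adam optimizer, before even counting activations):

```python
# Rough memory accounting for fully pre-training a 7B model with Adam in bf16.
# (Back-of-the-envelope numbers, not exact figures from the paper.)
params = 7e9
bytes_per_value = 2                          # bf16 = 2 bytes

weights     = params * bytes_per_value       # the model itself
gradients   = params * bytes_per_value       # one gradient per weight
adam_states = params * bytes_per_value * 2   # Adam keeps 2 moment estimates per weight

print(f"~{(weights + gradients + adam_states) / 1e9:.0f} GB")  # ~56 GB, before activations
```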

The common workarounds to reduce memory load are:
- split the model across multiple GPUs ("sharding")
- quantize models: encode weights with fewer bits (a quick sketch follows the list)
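
As an illustration of those two workarounds, here is a sketch using the transformers + bitsandbytes APIs (the checkpoint name is just an example; this style of loading is mostly used for inference and fine-tuning rather than pre-training):

```python
# Sketch: load a 7B model with 4-bit quantized weights, spread across available GPUs.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # store weights on 4 bits instead of 16
    bnb_4bit_compute_dtype=torch.bfloat16,  # but compute in bf16
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",             # example checkpoint
    quantization_config=bnb_config,
    device_map="auto",                      # spread layers over the devices you have
)
```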

Another technique is to project the weight matrices to lower-rank spaces (since the weights often don't really vary along all dimensions): this can save a lot of space!
This low-rank projection can be done on adapters to preserve the original weights (go check out LoRA), but it still generally hurts performance too much for pre-training.
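
For intuition, here is a toy LoRA-style layer in plain PyTorch (a sketch of the idea, not the peft library's implementation): a frozen 4096×4096 projection (~16.8M weights) gets a trainable rank-8 update of only ~65k parameters.

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Frozen linear layer plus a trainable low-rank update B @ A (rank r)."""
    def __init__(self, base: nn.Linear, r: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)                    # original weights stay frozen
        self.A = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.B = nn.Parameter(torch.zeros(base.out_features, r))  # starts as a no-op
        self.scale = alpha / r

    def forward(self, x):
        return self.base(x) + self.scale * (x @ self.A.T @ self.B.T)

layer = LoRALinear(nn.Linear(4096, 4096), r=8)
print(sum(p.numel() for p in layer.parameters() if p.requires_grad))  # 65536 trainable params
```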

โžก๏ธ Enter the authors of ๐˜Ž๐˜ข๐˜“๐˜ฐ๐˜ณ๐˜ฆ: ๐˜”๐˜ฆ๐˜ฎ๐˜ฐ๐˜ณ๐˜บ-๐˜Œ๐˜ง๐˜ง๐˜ช๐˜ค๐˜ช๐˜ฆ๐˜ฏ๐˜ต ๐˜“๐˜“๐˜” ๐˜›๐˜ณ๐˜ข๐˜ช๐˜ฏ๐˜ช๐˜ฏ๐˜จ ๐˜ฃ๐˜บ ๐˜Ž๐˜ณ๐˜ข๐˜ฅ๐˜ช๐˜ฆ๐˜ฏ๐˜ต ๐˜“๐˜ฐ๐˜ธ-๐˜™๐˜ข๐˜ฏ๐˜ฌ ๐˜—๐˜ณ๐˜ฐ๐˜ซ๐˜ฆ๐˜ค๐˜ต๐˜ช๐˜ฐ๐˜ฏ. They gather (and prove) interesting insights:
โ›” The weight matrix does not reliably converge to lower ranks during training.
โœ… But the gradient matrix does!

Based on these insights, they build GaLore, which projects the gradient to lower ranks.
🗺️ Great idea: to leave the optimization free to explore more of the space, they periodically rebuild the low-rank projection throughout training (a nice illustration is in the paper).
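
A toy sketch of that gradient-projection idea (heavily simplified: the real GaLore works per layer, inside the optimizer, and keeps the Adam statistics in the small space; the names here are made up for illustration):

```python
import torch

def galore_style_step(weight, grad, state, step, rank=4, update_proj_gap=200, lr=1e-3):
    # Periodically rebuild the projection so optimization can explore new subspaces.
    if step % update_proj_gap == 0 or "P" not in state:
        U, _, _ = torch.linalg.svd(grad, full_matrices=False)
        state["P"] = U[:, :rank]          # (m, r) projector taken from the gradient's SVD

    P = state["P"]
    small_grad = P.T @ grad               # project: an (r, n) matrix instead of (m, n)
    # ...optimizer statistics (e.g. Adam moments) would live in this small space...
    update = P @ small_grad               # project the update back to full size
    weight -= lr * update                 # plain SGD here, just to keep the sketch short
```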

๐Ÿค This method can even be combined with previous ones such as 8-bit Adam (quantizing the optimizer states to 8-bit).

➡️ Results:
📉 Of course, a huge reduction in memory footprint, allowing training on a consumer-grade GPU (see the figure in the paper).
💪 No reduction in performance: this scales well up to 7B parameters (and has since been independently confirmed) ⇒ this is essential, as it confirms that the method is viable!

Read the full paper here: GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection (2403.03507)