fine tune memory?

#2
by glicerico - opened

I am trying to fine-tune this model with DeepSpeed, as suggested in the model's repo: https://github.com/salesforce/jaxformer#a100-fine-tune
I have tried with up to 4 x A100 and a total of 360GB of RAM, but every time the training crashes before it starts, once the RAM is fully used (monitored with htop).
How much memory do I need to fine-tune this?

Salesforce org

Here is a configuration for DeepSpeed that should fit on a single A100 with CPU offloading, though it may be slow:
https://github.com/salesforce/jaxformer/blob/main/jaxformer/hf/train.py
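For readers unsure what CPU offloading looks like in a DeepSpeed config, here is a minimal sketch. The values are illustrative assumptions and are not copied from train.py, so the linked script may use different settings. The keys `zero_optimization`, `offload_optimizer`, and `offload_param` are standard DeepSpeed config keys; the helper `cpu_offload_targets` is a hypothetical name introduced here only to show how one might check which state is offloaded.

```python
import json

# Illustrative values only (assumption): the real settings in
# jaxformer/hf/train.py may differ from these.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 1,
    "fp16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,
        # Keep optimizer state in host RAM instead of GPU memory.
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
        # Keep parameters in host RAM as well (requires ZeRO stage 3).
        "offload_param": {"device": "cpu", "pin_memory": True},
    },
}


def cpu_offload_targets(config):
    """Return which ZeRO state groups are configured to offload to CPU.

    Hypothetical helper: it only inspects the config dict and does not
    call DeepSpeed itself.
    """
    zero = config.get("zero_optimization", {})
    targets = []
    for name in ("optimizer", "param"):
        section = zero.get(f"offload_{name}") or {}
        if section.get("device") == "cpu":
            targets.append(name)
    return targets


print(json.dumps(ds_config, indent=2))
print("CPU offload enabled for:", cpu_offload_targets(ds_config))
```

Running the same check against the config actually passed to DeepSpeed is one way to confirm whether offloading is active. Note that offloading moves state from GPU memory into host RAM, so it raises host memory use rather than lowering it.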

Thanks for replying, @enijkamp. This is exactly what I am trying to use (with my own training data, a longer run, and checkpoint saving), but as I said above, loading the model uses more than 360GB of RAM.
I am not sure whether I am activating CPU offloading, though... I suppose the default parameters in that file are enough?
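As a rough sanity check on the memory numbers, a back-of-the-envelope estimate may help. The sketch below assumes a 16-billion-parameter model, mixed-precision training with Adam (commonly counted as about 16 bytes per parameter: fp16 weights and gradients plus fp32 master weights, momentum, and variance), and that every process loads its own full fp32 copy of the weights before sharding. None of these assumptions is confirmed for this script, and activations and framework overhead are ignored. The function names are hypothetical.

```python
def training_state_gb(num_params, bytes_per_param=16):
    """Approximate weight, gradient and optimizer state size in GB (1e9 bytes).

    16 bytes/param is the usual mixed-precision Adam estimate:
    2 (fp16 weights) + 2 (fp16 grads) + 4 + 4 + 4 (fp32 master weights,
    momentum, variance). Activations are not included.
    """
    return num_params * bytes_per_param / 1e9


def load_copies_gb(num_params, num_processes, bytes_per_weight=4):
    """Host RAM in GB if each process loads a full copy of the weights.

    Assumption: one full fp32 copy per process at load time; whether this
    happens depends on how the training script loads the checkpoint.
    """
    return num_params * bytes_per_weight * num_processes / 1e9


NUM_PARAMS = 16e9  # assumption: the "16B" model size

print(f"Training state with full CPU offload: ~{training_state_gb(NUM_PARAMS):.0f} GB")
for procs in (1, 4):
    print(f"fp32 weight copies for {procs} process(es): ~{load_copies_gb(NUM_PARAMS, procs):.0f} GB")
```

Under these assumptions, the offloaded training state and the per-process weight copies are each in the hundreds of GB, so running out of 360GB of host RAM at load time on 4 GPUs would not be surprising.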

@enijkamp I've succeeded in fine-tuning on TPU, but unfortunately I can't find the 16B model checkpoints for this. I have read in last year's issues that you haven't had time to upload the sharding patterns... Any update on this?
