How to properly save/load the 8-bit model?

#12
by Beyondo - opened

Hello, firstly thank you for making GPT-J and many other language models more accessible!

I've successfully converted the GPT-Neo 1.3B model using the notebook you linked. But when I save it with the transformers save_pretrained() function and then load it back, it reports that the weights are destroyed, and the model reloaded from the saved files is dead when I try to use it.

So assuming I want to train the model in 8-bit with mixed precision using your notebook, should I just convert it back to 32-bit once training is done? And if so, what's the proper way to do a reverse_bnbfy?

Or is there a way to save the 8-bit weights and load them back without resorting to saving in 32-bit? If so, could someone guide me on how?
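
To make the two options concrete, here is a rough sketch of what I imagine converting to 8-bit and back to floats involves. It assumes a simple blockwise absmax scheme, which may well not be what the notebook actually uses, and absmax_quantize / absmax_dequantize are hypothetical helpers written only for illustration, not functions from transformers or the notebook:

```python
from typing import List, Tuple


def absmax_quantize(weights: List[float], block_size: int = 4) -> Tuple[List[int], List[float]]:
    """Hypothetical helper: map each block of floats to int8 codes plus one scale."""
    codes: List[int] = []
    scales: List[float] = []
    for start in range(0, len(weights), block_size):
        block = weights[start:start + block_size]
        # The largest magnitude in the block becomes the scale (1.0 for an all-zero block).
        scale = max(abs(w) for w in block) or 1.0
        scales.append(scale)
        codes.extend(round(w / scale * 127) for w in block)
    return codes, scales


def absmax_dequantize(codes: List[int], scales: List[float], block_size: int = 4) -> List[float]:
    """Hypothetical helper: the reverse step, rebuilding approximate floats."""
    return [code / 127 * scales[i // block_size] for i, code in enumerate(codes)]


# Toy weights standing in for one flattened layer (made-up values).
weights = [0.5, -1.0, 0.25, 0.0, 2.0, -0.5, 1.0, 0.1]
codes, scales = absmax_quantize(weights)
restored = absmax_dequantize(codes, scales)

# The round trip is lossy, so the restored values only approximate the originals.
max_error = max(abs(a - b) for a, b in zip(weights, restored))
print(codes, scales, max_error)
```

If it works anything like this, then saving in 8-bit would mean storing both the integer codes and the per-block scales, since the codes alone can't be turned back into weights.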
