
Training Errors in the Original Implementation and Their Fixes

#24
by v2ray

https://huggingface.co/v2ray/dbrx-base-fixed
The original DBRX implementation code has a few bugs which only affect training; I fixed them in my re-upload.
I re-uploaded rather than just patching the code because the fix requires the weight files to be converted, so if anyone wants to use the fix, you need to re-download the entire set of weights!

The issues and how I fixed them:

  1. Error when using gradient checkpointing - Fixed by passing positional arguments instead, because _gradient_checkpointing_func doesn't support kwargs (see the first sketch after this list).
  2. VRAM usage goes zoom and CUDA runs out of memory when backpropping through the MLP layer - Fixed by separating the experts' weights into different tensors instead of using a single tensor for all the experts (see the second sketch after this list). I don't know for sure why this fixed it, but maybe torch was trying to compute gradients for every expert at once, which shouldn't happen since it's a MoE model where only the routed experts are active.
Databricks org

Error 1 should be fixed in this PR.
We are currently working on fixing Error 2 in the same PR. For a current workaround, please see: huggingface.co/databricks/dbrx-instruct/discussions/10#660566f14f41c0c7c0e54ab9
