FSDP Finetuning

#12
by cchristophe - opened

Hi, I'm trying to fine-tune Mixtral using the FSDP framework and I hit this error during the first backward pass:
Exception: Attention mask should be of size (1, 1, 4096, 8192), but is torch.Size([1, 1, 4096, 4096])

I'm using the same logic and the same data I used to fine-tune Mistral 7B...

Getting this error as well.

Thanks, could you open an issue on https://github.com/huggingface/transformers with a full reproducer?

Is there any corresponding issue?

I believe it has been recently fixed by: https://github.com/huggingface/transformers/pull/28061
You can use the main branch of transformers: pip install -U git+https://github.com/huggingface/transformers.git

@ybelkada I can confirm that after moving to the latest HF as mentioned above, I am able to fine-tune Mixtral using FSDP. :tada:

@rganti Can you please share your FSDP config?
I am trying a full fine-tuning (not LoRA) using
auto_wrap_policy={MixtralDecoderLayer}, activation_checkpointing_policy={MixtralDecoderLayer}
according to https://lightning.ai/docs/fabric/stable/advanced/model_parallel/fsdp.html

It is giving me a recomputed tensor size mismatch error. A detailed bug report is here.
FYI: I tried the latest transformers and lightning libraries installed from git+https
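
For reference, a rough sketch of the Lightning FSDP strategy setup described above (full fine-tuning, no LoRA), following the linked docs; the import paths and the way the strategy is passed to the Trainer are my assumptions, not the exact script:

from lightning.pytorch.strategies import FSDPStrategy
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer

# Shard and checkpoint at the decoder-layer boundary, as in the Lightning FSDP guide.
strategy = FSDPStrategy(
    auto_wrap_policy={MixtralDecoderLayer},
    activation_checkpointing_policy={MixtralDecoderLayer},
)
# The strategy object is then handed to the Trainer, e.g. pl.Trainer(strategy=strategy, ...)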

{
  "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
  "fsdp_backward_prefetch_policy": "BACKWARD_PRE",
  "fsdp_cpu_ram_efficient_loading": "False",
  "fsdp_forward_prefetch": "True",
  "fsdp_offload_params": "False",
  "fsdp_state_dict_type": "SHARDED_STATE_DICT",
  "fsdp_sync_module_states": "False",
  "fsdp_transformer_layer_cls_to_wrap": "MixtralDecoderLayer",
  "fsdp_use_orig_params": "True",
  "activation_checkpointing": "True"
}

I am using SFTTrainer
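
For context, here is a minimal sketch of how a config like the one above can be wired into SFTTrainer via TrainingArguments; the model id, toy dataset, output path, and the trl constructor arguments reflect the API of that time and are my assumptions, not @rganti's actual setup:

from datasets import Dataset
from transformers import TrainingArguments
from trl import SFTTrainer

# Toy dataset just to keep the example self-contained.
train_dataset = Dataset.from_dict({"text": ["Hello world", "FSDP fine-tuning example"]})

args = TrainingArguments(
    output_dir="mixtral-fsdp",            # placeholder output path
    per_device_train_batch_size=1,
    bf16=True,
    fsdp="full_shard auto_wrap",          # enables the Trainer's FSDP integration
    fsdp_config="fsdp_config.json",       # the JSON shown above, saved to a file
)

trainer = SFTTrainer(
    model="mistralai/Mixtral-8x7B-v0.1",  # placeholder model id
    args=args,
    train_dataset=train_dataset,
    dataset_text_field="text",
    max_seq_length=4096,
)
trainer.train()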

btw @hrushikesh1 -- another model (GPTBigCode) is giving me the same trouble (size/shape mismatch); it used to work well for me in the past :)

@hrushikesh1 To update, it seems to be flaky and dependent on the PyTorch and HF versions that are installed. I am still trying to figure out the "right" combination, but perhaps @ybelkada or someone from HF/PT teams can comment?

Specifically, I'm using torch 2.2.0.dev20231121+cu118, transformers 4.37.0.dev0, and Python 3.11.

Thanks for the info @rganti!
I was able to solve it by explicitly calling
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': True})
after loading the model with AutoModel.from_pretrained().

The issue reappears if I set use_reentrant: False in the above call; the Lightning library might be defaulting to use_reentrant=False.
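
Spelled out, the workaround is just the following (a minimal sketch; the model id and the use of AutoModelForCausalLM are examples, not necessarily what was used here):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")

# Explicitly select the reentrant checkpointing implementation instead of relying on
# framework defaults; this is what resolved the recomputed-tensor size mismatch above.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})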

There are a lot of notes and warnings from PyTorch on the reentrant behavior here.
As of torch 2.1 it defaults to True, but they plan to move to use_reentrant=False as the default in the future; that might be causing the flakiness you observe across versions.
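
To illustrate the API in question (not code from this thread), torch.utils.checkpoint lets you pin the behavior explicitly instead of relying on the version-dependent default:

import torch
from torch.utils.checkpoint import checkpoint

def block(x):
    # Stand-in for a transformer block whose activations we don't want to keep.
    return torch.relu(x @ x.T)

x = torch.randn(8, 8, requires_grad=True)
# Passing use_reentrant explicitly avoids surprises when the default flips in a future release.
y = checkpoint(block, x, use_reentrant=True)
y.sum().backward()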

@hrushikesh1 I was able to LoRA-tune Mixtral on the latest PyTorch nightlies and the latest HF main after adding the above line, thanks!

Wanted to share a note for a future data scientist in trouble:
I was trying LoRA fine-tuning of Mistral-7B using the FSDP strategy and the PyTorch Lightning trainer. It kept getting stuck at step 1.
It turned out that, since LoRA leaves some frozen parameters without gradients, I cannot use gradient clipping.
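
A minimal sketch of what that means in practice, assuming the Lightning 2.x Trainer API (not the poster's exact script):

import lightning.pytorch as pl

# With LoRA under FSDP, many of the flat-wrapped parameters are frozen and have no gradients,
# so leave gradient clipping off (gradient_clip_val=None is the default).
trainer = pl.Trainer(
    strategy="fsdp",
    precision="bf16-mixed",
    gradient_clip_val=None,
)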

Hi, did you fine-tune Mixtral 8x7B with any adapter, or just regular full fine-tuning with FSDP? Can you share your GPU compute resource info?
