Error in Flash Attention
Hi, I am getting the following error when running backpropagation through Phi-3-small-128k-instruct. I updated my code following the fix suggested here, but the issue persists: https://huggingface.co/microsoft/Phi-3-small-128k-instruct/commit/ed7de9a074b0760e6cf050fe1d103b90834933c8
new block_sparse_attn op constructed with config: n_heads=32, max_seq_len=131072, sparse_block_size=64, local_blocks=16, vert_stride=8, homo_head=False, active_head_range=None, kwargs={'kernel_block_size': 64, 'inference': True}
Traceback (most recent call last):
  File ".../Code/Phi3/Phi3-C4-small-L-Cosine-Masked-All.py", line 257, in
    loss_sum.backward()
  File "....conda/envs/demo/lib/python3.12/site-packages/torch/_tensor.py", line 522, in backward
    torch.autograd.backward(
  File ".../.conda/envs/demo/lib/python3.12/site-packages/torch/autograd/__init__.py", line 266, in backward
    Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
  File ".../.conda/envs/demo/lib/python3.12/site-packages/torch/autograd/function.py", line 289, in apply
    return user_fn(self, *args)
           ^^^^^^^^^^^^^^^^^^^^
  File ".../huggingface/modules/transformers_modules/microsoft/Phi-3-small-128k-instruct/ad85cab62be398dc90203c4377a4ccbf090fbb36/triton_flash_blocksparse_attn.py", line 906, in backward
    return _backward(ctx, do, *backward_layout)[:4]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../huggingface/modules/transformers_modules/microsoft/Phi-3-small-128k-instruct/ad85cab62be398dc90203c4377a4ccbf090fbb36/triton_flash_blocksparse_attn.py", line 683, in _backward
    delta = torch.empty_like(l)
            ^^^^^^^^^^^^^^^^^^^
TypeError: empty_like(): argument 'input' (position 1) must be Tensor, not NoneType
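The log line above shows the op was constructed with `'inference': True`. One plausible but unconfirmed explanation is that in inference mode the forward pass does not save the softmax log-sum-exp tensor `l`, so `_backward` receives `None` and `torch.empty_like(l)` fails. The sketch below illustrates that failure pattern with a toy dense-attention `torch.autograd.Function`. `ToyAttn`, its `inference` flag, and `run` are hypothetical stand-ins written for illustration, not the Phi-3 Triton kernel or its API.

```python
import torch


class ToyAttn(torch.autograd.Function):
    """Hypothetical dense stand-in for the block-sparse attention op.

    Assumption: when `inference` is True, the forward pass skips saving the
    per-row log-sum-exp `l`, so backward later sees None in its place.
    """

    @staticmethod
    def forward(ctx, q, k, v, inference):
        scores = q @ k.transpose(-2, -1)           # no 1/sqrt(d) scaling, for brevity
        l = torch.logsumexp(scores, dim=-1)        # softmax normalizer per query row
        p = torch.exp(scores - l.unsqueeze(-1))    # softmax probabilities
        o = p @ v
        # Saving None is allowed; it comes back as None in ctx.saved_tensors.
        ctx.save_for_backward(q, k, v, o, None if inference else l)
        return o

    @staticmethod
    def backward(ctx, do):
        q, k, v, o, l = ctx.saved_tensors
        # Same call as in the traceback: raises TypeError when l is None.
        delta = torch.empty_like(l)
        delta.copy_((do * o).sum(dim=-1))          # rowsum(dO * O), FlashAttention-style
        p = torch.exp(q @ k.transpose(-2, -1) - l.unsqueeze(-1))  # recompute softmax
        dv = p.transpose(-2, -1) @ do
        dp = do @ v.transpose(-2, -1)
        ds = p * (dp - delta.unsqueeze(-1))        # softmax backward
        dq = ds @ k
        dk = ds.transpose(-2, -1) @ q
        return dq, dk, dv, None                    # no gradient for the `inference` flag


def run(inference):
    torch.manual_seed(0)
    q, k, v = (torch.randn(2, 8, 4, dtype=torch.float64, requires_grad=True) for _ in range(3))
    ToyAttn.apply(q, k, v, inference).sum().backward()
    return q.grad


run(inference=False)       # backward succeeds because l was saved
try:
    run(inference=True)    # l was not saved, so empty_like(None) raises TypeError
except TypeError as err:
    print(err)
```

If that assumption holds for the real kernel, the backward pass can only succeed when the op is built with inference mode disabled, so that `l` is stored during the forward pass.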