Why is the navit version slower than the normal version?

Reposting this discussion from @yuzaa because I deleted the debug repo they created this question in:

I found that the forward pass of the navit version is about twice as slow as the normal version at the same resolution (GPU: A800):

import torch
from transformers import AutoModel
base = AutoModel.from_pretrained("HuggingFaceM4/siglip-so400m-14-384-flash-attn2", trust_remote_code=True)
navit = AutoModel.from_pretrained("HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit", trust_remote_code=True)

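# Move both vision towers to the GPU in bfloat16 to match pixel_values (Flash Attention 2 only runs in fp16/bf16 on CUDA)
base_vision = base.vision_model.to("cuda", dtype=torch.bfloat16)
navit_vision = navit.vision_model.to("cuda", dtype=torch.bfloat16)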

pixel_values = torch.randn(1, 3, 384, 384).bfloat16().cuda()

# %%time
for i in range(100):
    x = base_vision(pixel_values)

# CPU times: user 1.21 s, sys: 12.4 ms, total: 1.22 s
# Wall time: 1.22 s

# %%time
for i in range(100):
    x = navit_vision(pixel_values)

# CPU times: user 2.63 s, sys: 36.3 ms, total: 2.66 s
# Wall time: 2.66 s
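CUDA kernels are launched asynchronously, so a plain wall-clock loop like the one above can end up measuring mostly Python and kernel-launch overhead rather than GPU execution time. A minimal timing helper, assuming a PyTorch setup like the one above (the name `time_forward` and the warmup count are hypothetical choices), synchronizes the device before each clock read:

```python
import time

try:
    import torch
except ImportError:  # let the helper still run on machines without PyTorch
    torch = None


def _sync():
    # CUDA work is queued asynchronously; wait for it to finish before reading the clock
    if torch is not None and torch.cuda.is_available():
        torch.cuda.synchronize()


def time_forward(fn, x, iters=100, warmup=10):
    """Return the average seconds per call of fn(x), excluding warmup calls.

    Hypothetical helper: the warmup calls absorb one-off costs (CUDA context
    setup, kernel selection) so they do not skew the average.
    """
    for _ in range(warmup):
        fn(x)
    _sync()
    start = time.perf_counter()
    for _ in range(iters):
        fn(x)
    _sync()
    return (time.perf_counter() - start) / iters
```

With the models above, comparing `time_forward(base_vision, pixel_values)` against `time_forward(navit_vision, pixel_values)`, ideally inside `torch.no_grad()`, gives per-call times that include the GPU work for both paths.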
I don't know why yet; I will dig into it this week. There should not be such a speed overhead...

Note that this model has not been trained yet since the position embeddings were interpolated and the navit-style image handling was introduced.

So it looks like the two Flash Attention 2 code paths have different speeds: flash_attn_varlen_func, which is used when an attention mask is passed, and flash_attn_func, which is used when no attention mask is passed.
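The varlen path cannot work on the padded `(batch, seq_len)` layout directly: it needs the kept tokens packed into one flat sequence, plus cumulative sequence lengths marking where each sample starts. A pure-Python sketch of that bookkeeping follows; the helper name `unpad_metadata` is hypothetical, and the mask is modeled as nested lists of 0/1 rather than a tensor (in PyTorch the same quantities would come from operations such as `nonzero` and `cumsum` on the mask).

```python
from itertools import accumulate


def unpad_metadata(attention_mask):
    """Compute what a varlen attention call needs from a 0/1 padding mask.

    attention_mask: list of rows of 0/1 ints (1 = real token), one row per sequence.
    Returns (indices, cu_seqlens, max_seqlen):
      indices    - flat positions of the kept tokens in the (batch * seq_len) layout
      cu_seqlens - cumulative sequence lengths, starting at 0
      max_seqlen - length of the longest sequence
    """
    flat = [v for row in attention_mask for v in row]
    indices = [i for i, v in enumerate(flat) if v]
    seqlens = [sum(row) for row in attention_mask]
    cu_seqlens = [0] + list(accumulate(seqlens))
    return indices, cu_seqlens, max(seqlens)
```

For example, `unpad_metadata([[1, 1, 0], [1, 1, 1]])` gives indices `[0, 1, 3, 4, 5]`, cu_seqlens `[0, 2, 5]` and max_seqlen `3`. With an all-ones mask, `indices` is simply every position, so this work (and the gather that uses it) is pure overhead compared with the fixed-length path. On a GPU, reading a value like `max_seqlen` back to the host (e.g. with `.item()`) also forces a device synchronization, which is one plausible contributor to the cost.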
The call to _upad_input is expensive when the attention_mask is passed.
I am fixing this now.
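One plausible way to remove the overhead, sketched here as an assumption rather than the fix that was actually applied: when the mask keeps every patch (presumably the case for the single un-padded 384x384 image in the benchmark above), it is equivalent to passing no mask at all, so it can be dropped before the attention call. The helpers `maybe_drop_mask` and `choose_flash_path` are hypothetical, and the mask is again modeled as nested lists; with a PyTorch tensor the check would be something like `attention_mask.all()`.

```python
def maybe_drop_mask(attention_mask):
    """Return None when the mask keeps every token, else the mask unchanged.

    Hypothetical fix: an all-ones mask means "attend to everything", which is
    what the no-mask path already does, so dropping it avoids the unpadding step.
    """
    if attention_mask is not None and all(v for row in attention_mask for v in row):
        return None
    return attention_mask


def choose_flash_path(attention_mask):
    # Simplified dispatch: any non-None mask goes through unpadding + the varlen kernel
    return "flash_attn_func" if attention_mask is None else "flash_attn_varlen_func"
```

A batch with real padding keeps its mask and still takes the varlen path, so the optimization only changes behavior in the case where the mask carries no information.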

