Using .item() prevents exporting with torch.export

#61
by bartel97 - opened

Hi!
In the classes Phi3SuScaledRotaryEmbedding and Phi3YarnScaledRotaryEmbedding we see this kind of pattern:

```python
def forward(self, x, position_ids, seq_len=None):
    seq_len = torch.max(position_ids) + 1
    if seq_len > self.original_max_position_embeddings:
        ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
    else:
        ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
```

torch.export.export() fails at `if seq_len > self.original_max_position_embeddings:` because seq_len is data dependent (it is computed from the values in position_ids, not from their shape) and Python control flow branches on it. Two questions:

  1. Why is seq_len recomputed when it is already passed to the function? It is passed at every call site, so it seems to me the argument could be made required and the recomputation removed. I don't think that alone fixes the following code in the flash attention path, though, which again branches on a data-dependent value (here obtained via .item()):
     ```python
     rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
     cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len)
     ```
  2. Is there a way to get the sequence length without calling max() on the tensor? If max() is needed, is there a way to branch on something that is not data dependent?
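To make question 1 concrete, here is a minimal sketch of what a required, shape-derived `seq_len` could look like. It assumes position ids are contiguous, so `max(position_ids) + 1` equals the total key/value length (new tokens plus cached tokens); under that assumption the caller can compute the length from tensor shapes instead of calling `.max()` or `.item()`. `ShapeDrivenFactorSelector` and `kv_length_from_shapes` are hypothetical stand-ins, not the actual Phi-3 modules. Branching on a shape-derived value is generally easier for torch.export to handle than branching on tensor contents, although with dynamic shapes it may still specialize or require constraining the dimension.

```python
import torch


class ShapeDrivenFactorSelector(torch.nn.Module):
    """Hypothetical stand-in for the factor selection in the Phi-3 rotary embeddings.

    seq_len is a required int (a symbolic int during export) supplied by the
    caller from tensor shapes, so the branch does not depend on tensor values.
    """

    def __init__(self, short_factor, long_factor, original_max_position_embeddings):
        super().__init__()
        self.short_factor = short_factor
        self.long_factor = long_factor
        self.original_max_position_embeddings = original_max_position_embeddings

    def forward(self, x: torch.Tensor, seq_len: int) -> torch.Tensor:
        # Branch on a shape-derived length instead of torch.max(position_ids) + 1
        if seq_len > self.original_max_position_embeddings:
            factors = self.long_factor
        else:
            factors = self.short_factor
        return torch.tensor(factors, dtype=torch.float32, device=x.device)


def kv_length_from_shapes(key_states: torch.Tensor, past_len: int) -> int:
    # Hypothetical helper: new keys along dim -2 plus the cached length.
    # Only equals max(position_ids) + 1 if position ids are contiguous (assumption).
    return key_states.shape[-2] + past_len
```

For example, with 3 new tokens (`key_states` of shape `(batch, heads, 3, head_dim)`) and 2 cached tokens, the helper returns 5, the same value `max(position_ids) + 1` gives for positions 2, 3, 4.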

For reference, I am using the torch.export API from https://pytorch.org/tutorials/intermediate/torch_export_tutorial.html and seeing the error described in https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit#heading=h.r02234kuof4f. That Google doc also describes strategies that might help here.
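One common way to remove a data-dependent branch (not necessarily one of the strategies in that doc) is to avoid Python control flow entirely: keep `seq_len` as a 0-d tensor and select between the two factor sets with `torch.where`, which takes both candidates and picks elementwise. The sketch below assumes both factor lists have the same length; `BranchFreeFactorSelector` is a hypothetical name, and whether the full rotary embedding then exports cleanly is not verified here.

```python
import torch


class BranchFreeFactorSelector(torch.nn.Module):
    """Hypothetical stand-in: picks long vs. short factors without Python branching."""

    def __init__(self, short_factor, long_factor, original_max_position_embeddings):
        super().__init__()
        if len(short_factor) != len(long_factor):
            raise ValueError("short_factor and long_factor must have the same length")
        self.original_max_position_embeddings = original_max_position_embeddings
        # Non-persistent buffers follow the module across .to(device) but stay out of state_dict
        self.register_buffer(
            "short_factor", torch.tensor(short_factor, dtype=torch.float32), persistent=False
        )
        self.register_buffer(
            "long_factor", torch.tensor(long_factor, dtype=torch.float32), persistent=False
        )

    def forward(self, position_ids: torch.Tensor) -> torch.Tensor:
        # Still computed from values, but it stays a 0-d tensor: no .item(), no bool()
        seq_len = torch.max(position_ids) + 1
        use_long = seq_len > self.original_max_position_embeddings
        # The 0-d boolean condition broadcasts against the 1-d factor tensors
        return torch.where(use_long, self.long_factor, self.short_factor)
```

The trade-off is that both factor tensors are always materialized; for a short per-dimension scaling vector that cost is small.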
