`long_factor` is never used?

#22
by J22 - opened

long_factor is never used, since inv_freq is likely to always be initialized from short_factor (it is only computed once, on the first call). Is there something wrong here?

    

    @torch.no_grad()
    def forward(self, x, position_ids, seq_len=None):
        position_ids_expanded = position_ids[:, None, :].float()
        if position_ids_expanded.shape[-1] > self.original_max_position_embeddings:
            ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
        else:
            ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)

        if self.inv_freq is None:
            self.inv_freq = 1.0 / (
                ext_factors
                * self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
            )
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
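For reference, a minimal standalone sketch of the issue (not the model code; dim, base, and the factor values below are made up, with original_max_position_embeddings assumed to be 4096): whichever factor is selected on the first call is baked into the cached inv_freq and reused for every later call.

    # Toy reproduction of the caching behaviour; all values are illustrative.
    import torch

    dim, base, original_max = 8, 10000.0, 4096
    short_factor = [1.0] * (dim // 2)   # hypothetical rescale factors
    long_factor = [4.0] * (dim // 2)

    inv_freq = None

    def forward(position_ids):
        global inv_freq
        factors = long_factor if position_ids.shape[-1] > original_max else short_factor
        if inv_freq is None:  # computed once, then reused forever
            ext = torch.tensor(factors, dtype=torch.float32)
            inv_freq = 1.0 / (ext * base ** (torch.arange(0, dim, 2).float() / dim))
        return inv_freq

    first = forward(torch.arange(10)[None, :])    # short prompt first: short_factor baked in
    later = forward(torch.arange(8192)[None, :])  # long prompt later: cached inv_freq still used
    print(torch.equal(first, later))              # True, so long_factor never takes effect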
Microsoft org
edited Apr 24

Sorry about that, we were addressing the changes required for fully integrating into transformers and we missed this. It is fixed now, and the caching will be improved later on.

gugarosa changed discussion status to closed

This is still confusing. Suppose that the first time this is called, 5000 tokens are passed in; then long_factor is also used for the first 4096 tokens. Is this intentional?

    

    @torch.no_grad()
    def forward(self, x, position_ids, seq_len=None):
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.original_max_position_embeddings:
            ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
        else:
            ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
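For what it is worth, here is a small sketch of the selection above (original_max_position_embeddings assumed to be 4096): the factor is chosen once per forward pass from the largest position id, so a single 5000-token prompt rotates every position, including the first 4096, with long_factor, while prefilling in two chunks would switch factors between them.

    # Toy version of the per-call factor selection; values are illustrative.
    import torch

    original_max = 4096

    def pick_factors(position_ids):
        seq_len = torch.max(position_ids) + 1  # largest position id in this call
        return "long_factor" if seq_len > original_max else "short_factor"

    print(pick_factors(torch.arange(5000)[None, :]))        # long_factor (covers positions 0..4095 too)
    print(pick_factors(torch.arange(4096)[None, :]))        # short_factor
    print(pick_factors(torch.arange(4096, 5000)[None, :]))  # long_factor for the continuation only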
