[bug] get_position_ids

#11
by kebo - opened
    def get_position_ids(self, seq, mask_position, device, gmask=False):
        context_length = seq.index(self.config.bos_token_id) + 1
        if self.position_encoding_2d:
            seq_length = seq.index(self.config.bos_token_id)
            position_ids = torch.arange(context_length, dtype=torch.long, device=device)
            if not gmask:
                position_ids[seq_length:] = mask_position
            block_position_ids = torch.cat((
                torch.zeros(seq_length, dtype=torch.long, device=device),
                torch.arange(context_length - seq_length, dtype=torch.long, device=device) + 1
            ))
            position_ids = torch.stack((position_ids, block_position_ids), dim=0)
        else:
            position_ids = torch.arange(context_length, dtype=torch.long, device=device)
            if not gmask:
                position_ids[context_length - 1:] = mask_position

        position_ids = position_ids.unsqueeze(0)

        return position_ids

Since `context_length = seq.index(self.config.bos_token_id) + 1` and `seq_length = seq.index(self.config.bos_token_id)`, `context_length - seq_length` is always 1, so `torch.arange(context_length - seq_length, ...)` always has a single element. Why is it written this way?
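
For reference, here is a minimal sketch (toy token ids, not the real ChatGLM vocabulary) that just replays the arithmetic from the 2D branch of `get_position_ids`. It shows that because both values come from the same `seq.index(self.config.bos_token_id)` call, the second `torch.arange` produces exactly one element:

```python
import torch

bos_token_id = 99                                # hypothetical bos id
mask_token_id = 98                               # hypothetical [MASK]/[gMASK] id
seq = [11, 12, 13, mask_token_id, bos_token_id]  # toy prompt: tokens, mask, then bos

mask_position = seq.index(mask_token_id)         # 3
context_length = seq.index(bos_token_id) + 1     # 5
seq_length = seq.index(bos_token_id)             # 4
# context_length - seq_length == 1, always, since both use the same index

position_ids = torch.arange(context_length, dtype=torch.long)
block_position_ids = torch.cat((
    torch.zeros(seq_length, dtype=torch.long),
    torch.arange(context_length - seq_length, dtype=torch.long) + 1,
))
print(position_ids)        # tensor([0, 1, 2, 3, 4])
print(block_position_ids)  # tensor([0, 0, 0, 0, 1])
print(torch.stack((position_ids, block_position_ids), dim=0).shape)  # torch.Size([2, 5])
```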

zxdu20 changed discussion status to closed
