[bug] get_position_ids — issue #11, opened by kebo
def get_position_ids(self, seq, mask_position, device, gmask=False):
    """Build position ids for a sequence up to (and including) the bos token.

    The bos token marks the start of generation; every token before it is
    prompt/context. ``seq.index`` raises ``ValueError`` if the bos token is
    absent, so callers must pass a sequence that already contains it.

    Args:
        seq: token id list containing ``self.config.bos_token_id``.
        mask_position: index of the mask token in ``seq``.
        device: torch device for the returned tensor.
        gmask: presumably True when the mask is the generative [gMASK];
            for a plain [MASK] the positions from bos onward are clamped
            to ``mask_position``.

    Returns:
        ``torch.LongTensor`` of shape ``(2, context_length)`` when 2D
        position encoding is enabled, else ``(1, context_length)``.
    """
    # Position (0-based) of the bos token; everything before it is context.
    # NOTE(review): the original computed seq.index(...) twice; by
    # construction context_length - seq_length == 1 always (the issue's
    # question) — bos is the single "generated" token so far.
    seq_length = seq.index(self.config.bos_token_id)
    context_length = seq_length + 1  # bos itself is included

    if self.position_encoding_2d:
        position_ids = torch.arange(context_length, dtype=torch.long, device=device)
        if not gmask:
            # [MASK]: tokens from bos onward reuse the mask's position.
            position_ids[seq_length:] = mask_position
        # Second row: 0 for context tokens, then 1, 2, ... for generated
        # tokens (here just the bos, since context_length == seq_length + 1).
        block_position_ids = torch.cat((
            torch.zeros(seq_length, dtype=torch.long, device=device),
            torch.arange(context_length - seq_length, dtype=torch.long, device=device) + 1
        ))
        position_ids = torch.stack((position_ids, block_position_ids), dim=0)
    else:
        position_ids = torch.arange(context_length, dtype=torch.long, device=device)
        if not gmask:
            # context_length - 1 == seq_length: same clamp as the 2D branch.
            position_ids[context_length - 1:] = mask_position
        position_ids = position_ids.unsqueeze(0)
    return position_ids
context_length = seq.index(self.config.bos_token_id) + 1
seq_length = seq.index(self.config.bos_token_id)
Question: by construction `context_length - seq_length == 1` always holds here — why are both computed with separate `seq.index` calls?
zxdu20 changed the discussion status to closed.