zxdu20 committed
Commit a8ede82
Parent(s): f831824

Fix position ids in 1d position encoding

Files changed (1): modeling_chatglm.py (+1 -1)
modeling_chatglm.py CHANGED
@@ -708,7 +708,7 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
             position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
             for i, context_length in enumerate(context_lengths):
                 if not use_gmasks[i]:
-                    position_ids[context_length:] = mask_positions[i]
+                    position_ids[i, context_length:] = mask_positions[i]

        return position_ids
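The fix scopes the slice assignment to row i of the batch: without the batch index, each iteration overwrote the tail of every row with sample i's mask position. A minimal sketch of the corrected loop, using hypothetical batch values (the variable names come from the diff; the concrete lengths and positions are illustrative, not from the model):

import torch

# Hypothetical inputs for a batch of 2 sequences of length 6.
batch_size, seq_length = 2, 6
context_lengths = [3, 4]   # per-sample context lengths (illustrative)
mask_positions = [2, 3]    # per-sample [MASK] positions (illustrative)
use_gmasks = [False, False]

# Same construction as the patched code path (device argument omitted here).
position_ids = torch.arange(seq_length, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1)
for i, context_length in enumerate(context_lengths):
    if not use_gmasks[i]:
        # Only row i is modified; the old code touched every row.
        position_ids[i, context_length:] = mask_positions[i]

print(position_ids)
# tensor([[0, 1, 2, 2, 2, 2],
#         [0, 1, 2, 3, 3, 3]])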