silver commited on
Commit
a6d4a44
2 Parent(s): a748e08 096f3de

Merge remote-tracking branch 'thu/main'

Browse files
Files changed (1) hide show
  1. modeling_chatglm.py +1 -1
modeling_chatglm.py CHANGED
@@ -770,7 +770,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
770
  return attention_mask
771
 
772
  def get_position_ids(self, seq, mask_position, device, gmask=False):
773
- context_length = seq.index(self.config.bos_token_id) + 1
774
  if self.position_encoding_2d:
775
  seq_length = seq.index(self.config.bos_token_id)
776
  position_ids = torch.arange(context_length, dtype=torch.long, device=device)
 
770
  return attention_mask
771
 
772
  def get_position_ids(self, seq, mask_position, device, gmask=False):
773
+ context_length = len(seq)
774
  if self.position_encoding_2d:
775
  seq_length = seq.index(self.config.bos_token_id)
776
  position_ids = torch.arange(context_length, dtype=torch.long, device=device)