zxdu20 committed
Commit 096f3de
1 Parent(s): 4a9b711

Fix context length in get_position_ids

Files changed (1)
  1. modeling_chatglm.py +1 -1
modeling_chatglm.py CHANGED
@@ -769,7 +769,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         return attention_mask
 
     def get_position_ids(self, seq, mask_position, device, gmask=False):
-        context_length = seq.index(self.config.bos_token_id) + 1
+        context_length = len(seq)
         if self.position_encoding_2d:
             seq_length = seq.index(self.config.bos_token_id)
             position_ids = torch.arange(context_length, dtype=torch.long, device=device)
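
A minimal sketch of the effect of this change (not part of the commit; the token ids below are made up for illustration). The old expression stops counting at the <bos> token, so any tokens that appear after <bos> in seq, for instance target tokens appended after the prompt, fall outside context_length and the resulting position_ids come out shorter than the sequence. Using len(seq) makes context_length cover the full input.

import torch

# Hypothetical token ids for illustration only; in the model they come from the config.
bos_token_id = 4

# A sequence in which tokens follow <bos>.
seq = [11, 12, 13, bos_token_id, 21, 22]

old_context_length = seq.index(bos_token_id) + 1   # 4: stops right after <bos>
new_context_length = len(seq)                      # 6: covers the whole sequence

# The old variant yields position ids shorter than the sequence itself.
print(torch.arange(old_context_length, dtype=torch.long))  # tensor([0, 1, 2, 3])
print(torch.arange(new_context_length, dtype=torch.long))  # tensor([0, 1, 2, 3, 4, 5])

When <bos> is the last token of seq, both expressions give the same value, so the change only affects inputs that continue past <bos>.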