Merge remote-tracking branch 'thu/main'
Browse files- modeling_chatglm.py +1 -1
modeling_chatglm.py
CHANGED
@@ -770,7 +770,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|
770 |
return attention_mask
|
771 |
|
772 |
def get_position_ids(self, seq, mask_position, device, gmask=False):
|
773 |
-
context_length = seq
|
774 |
if self.position_encoding_2d:
|
775 |
seq_length = seq.index(self.config.bos_token_id)
|
776 |
position_ids = torch.arange(context_length, dtype=torch.long, device=device)
|
|
|
770 |
return attention_mask
|
771 |
|
772 |
def get_position_ids(self, seq, mask_position, device, gmask=False):
|
773 |
+
context_length = len(seq)
|
774 |
if self.position_encoding_2d:
|
775 |
seq_length = seq.index(self.config.bos_token_id)
|
776 |
position_ids = torch.arange(context_length, dtype=torch.long, device=device)
|