Fix position id for training
modeling_chatglm.py  +4 -6
@@ -845,9 +845,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
         if self.position_encoding_2d:
             position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
-            if not gmask:
-                for i, context_length in enumerate(context_lengths):
-                    position_ids[i, context_length:] = mask_positions[i]
+            for i, context_length in enumerate(context_lengths):
+                position_ids[i, context_length:] = mask_positions[i]
             block_position_ids = [torch.cat((
                 torch.zeros(context_length, dtype=torch.long, device=device),
                 torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
@@ -1053,9 +1052,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
         if self.position_encoding_2d:
             position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
-            if not gmask:
-                for i, context_length in enumerate(context_lengths):
-                    position_ids[i, context_length:] = mask_positions[i]
+            for i, context_length in enumerate(context_lengths):
+                position_ids[i, context_length:] = mask_positions[i]
             block_position_ids = [torch.cat((
                 torch.zeros(context_length, dtype=torch.long, device=device),
                 torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
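Both hunks make the same change: the old `if not gmask:` guard is dropped, so position ids after the context are overwritten with the mask position for [gMASK] inputs as well, not only for [MASK], which is what training samples built with [gMASK] need. Below is a minimal standalone sketch of the corrected logic under stated assumptions: the function name `build_2d_position_ids`, the toy token ids, and the final stacking step (in the model that stacking sits outside the lines shown in this diff) are illustrative, not part of the commit.

import torch

def build_2d_position_ids(input_ids, mask_positions, bos_token_id, device="cpu"):
    # Sketch of ChatGLM-style 2D position ids after this fix.
    batch_size, seq_length = input_ids.shape
    # Context = everything before the <bos> token in each row.
    context_lengths = [seq.tolist().index(bos_token_id) for seq in input_ids]
    # .clone() so the per-row writes below don't alias the shared arange buffer.
    position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length).clone()
    # The fix: run this loop unconditionally instead of under `if not gmask:`.
    for i, context_length in enumerate(context_lengths):
        position_ids[i, context_length:] = mask_positions[i]
    # Second coordinate: 0 over the context, then 1, 2, 3, ... for generated tokens.
    block_position_ids = torch.stack([torch.cat((
        torch.zeros(context_length, dtype=torch.long, device=device),
        torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
    )) for context_length in context_lengths])
    # Assumed stacking step, included only to make the sketch self-contained.
    return torch.stack((position_ids, block_position_ids), dim=1)

# Toy batch: "A B [gMASK] <bos> X Y", with hypothetical special-token ids.
BOS, GMASK = 130004, 130001
input_ids = torch.tensor([[5, 6, GMASK, BOS, 7, 8]])
mask_positions = torch.tensor([2])  # index of [gMASK] in the row
print(build_2d_position_ids(input_ids, mask_positions, bos_token_id=BOS))
# tensor([[[0, 1, 2, 2, 2, 2],      <- every post-context position points at the mask
#          [0, 0, 0, 1, 2, 3]]])    <- intra-block positions count up from 1

With the old guard in place, a [gMASK] row would have kept the plain incremental positions [0, 1, 2, 3, 4, 5] after the context, so the first coordinate seen during training would not have matched what generation produces.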