zxdu20 committed on
Commit
11c270c
1 Parent(s): 9c7416d

Fix position id for training

Files changed (1): modeling_chatglm.py (+4, -6)
modeling_chatglm.py CHANGED
@@ -845,9 +845,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
         if self.position_encoding_2d:
             position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
-            if not gmask:
-                for i, context_length in enumerate(context_lengths):
-                    position_ids[i, context_length:] = mask_positions[i]
+            for i, context_length in enumerate(context_lengths):
+                position_ids[i, context_length:] = mask_positions[i]
             block_position_ids = [torch.cat((
                 torch.zeros(context_length, dtype=torch.long, device=device),
                 torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
@@ -1053,9 +1052,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
         if self.position_encoding_2d:
             position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
-            if not gmask:
-                for i, context_length in enumerate(context_lengths):
-                    position_ids[i, context_length:] = mask_positions[i]
+            for i, context_length in enumerate(context_lengths):
+                position_ids[i, context_length:] = mask_positions[i]
             block_position_ids = [torch.cat((
                 torch.zeros(context_length, dtype=torch.long, device=device),
                 torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
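
Context (not part of the commit): before this change, the first dimension of the 2D position ids was only pinned to the mask position when the sequence used [MASK] (the `if not gmask:` guard); the patched code applies the pinning unconditionally, which the commit message ties to training. Below is a minimal sketch of the patched behaviour on a toy batch. The values of seq_length, context_lengths and mask_positions are made up for illustration, and a .clone() is added so the in-place write on the expanded tensor is clearly safe; this is not code from the repository.

import torch

# Hypothetical toy values: BOS found at index 5, mask token found at index 2.
seq_length, batch_size, device = 8, 1, "cpu"
context_lengths = [5]
mask_positions = [2]

# First dimension of the 2D position encoding: 0..seq_length-1 over the context,
# then pinned to the mask position for every token after the context.
position_ids = (
    torch.arange(seq_length, dtype=torch.long, device=device)
    .expand(batch_size, seq_length)
    .clone()  # clone so the in-place write below does not touch the expanded view
)
for i, context_length in enumerate(context_lengths):
    position_ids[i, context_length:] = mask_positions[i]

# Second dimension: zeros over the context, then 1, 2, 3, ... inside the generated block.
block_position_ids = torch.stack([torch.cat((
    torch.zeros(context_length, dtype=torch.long, device=device),
    torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
)) for context_length in context_lengths])

print(position_ids)        # tensor([[0, 1, 2, 3, 4, 2, 2, 2]])
print(block_position_ids)  # tensor([[0, 0, 0, 0, 0, 1, 2, 3]])

With a context of length 5 and the mask token at index 2, every token after the context keeps position 2 in the first dimension, while the block positions count 1, 2, 3 within the generated span; after this commit that holds for [gMASK] sequences as well.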