zxdu20 committed
Commit 4de8efe
1 Parent(s): 3a99d79

Change mask positions to batch

Files changed (1):
  1. modeling_chatglm.py +21 -11
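Before this change, the mask token and gmask flag were resolved once for the whole batch: gMASK if gMASK in input_ids else MASK scans the entire tensor, so a batch mixing [MASK]-style and [gMASK]-style sequences would treat every row as gMASK. The commit moves the scan inside a per-sequence loop, producing parallel mask_positions and use_gmasks lists. A minimal sketch of the new scan (the token ids here are hypothetical stand-ins for config.mask_token_id and config.gmask_token_id):

    # Per-sequence mask scan, as introduced by this commit (sketch).
    # MASK and gMASK are hypothetical stand-ins for the config token ids.
    import torch

    MASK, gMASK = 130000, 130001
    input_ids = torch.tensor([
        [10, 11, MASK, 12, 13],   # row 0 uses [MASK]
        [10, gMASK, 11, 12, 13],  # row 1 uses [gMASK]
    ])

    mask_positions, use_gmasks = [], []
    for seq in input_ids.tolist():
        mask_token = gMASK if gMASK in seq else MASK  # decided per row, not per batch
        use_gmasks.append(mask_token == gMASK)
        mask_positions.append(seq.index(mask_token))

    print(mask_positions)  # [2, 1]
    print(use_gmasks)      # [False, True]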
modeling_chatglm.py CHANGED
@@ -689,8 +689,10 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
 
         return attention_mask
 
-    def get_position_ids(self, input_ids, mask_positions, device, gmask=False):
+    def get_position_ids(self, input_ids, mask_positions, device, use_gmasks=None):
         batch_size, seq_length = input_ids.shape
+        if use_gmasks is None:
+            use_gmasks = [False] * batch_size
         context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
         if self.position_encoding_2d:
             position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
@@ -704,8 +706,8 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
             position_ids = torch.stack((position_ids, block_position_ids), dim=1)
         else:
             position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
-            if not gmask:
-                for i, context_length in enumerate(context_lengths):
+            for i, context_length in enumerate(context_lengths):
+                if not use_gmasks[i]:
                     position_ids[context_length:] = mask_positions[i]
 
         return position_ids
@@ -939,15 +941,20 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
 
         if position_ids is None:
             MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
-            mask_token = gMASK if gMASK in input_ids else MASK
-            use_gmask = True if gMASK in input_ids else False
+            seqs = input_ids.tolist()
+
+            mask_positions, use_gmasks = [], []
+            for seq in seqs:
+                mask_token = gMASK if gMASK in seq else MASK
+                use_gmask = mask_token == gMASK
+                mask_positions.append(seq.index(mask_token))
+                use_gmasks.append(use_gmask)
 
-            mask_positions = [seq.tolist().index(mask_token) for seq in input_ids]
             position_ids = self.get_position_ids(
                 input_ids,
                 mask_positions=mask_positions,
                 device=input_ids.device,
-                gmask=use_gmask
+                use_gmasks=use_gmasks
             )
 
         if self.pre_seq_len is not None and attention_mask is not None:
@@ -1106,10 +1113,13 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
     ) -> dict:
         batch_size, seq_length = input_ids.shape
         MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
-        mask_token = gMASK if gMASK in input_ids else MASK
-        use_gmask = True if gMASK in input_ids else False
         seqs = input_ids.tolist()
-        mask_positions = [seq.index(mask_token) for seq in seqs]
+        mask_positions, use_gmasks = [], []
+        for seq in seqs:
+            mask_token = gMASK if gMASK in seq else MASK
+            use_gmask = mask_token == gMASK
+            mask_positions.append(seq.index(mask_token))
+            use_gmasks.append(use_gmask)
 
         # only last token for input_ids if past is not None
         if past is not None or past_key_values is not None:
@@ -1152,7 +1162,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             input_ids,
             device=input_ids.device,
             mask_positions=mask_positions,
-            gmask=use_gmask
+            use_gmasks=use_gmasks
         )
 
         return {
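In get_position_ids itself, the old code skipped the clamping loop for the whole batch whenever gmask was set; the new code always runs the loop and checks use_gmasks[i] row by row. A self-contained sketch of the non-2D branch under that reading (bos_token_id is a stand-in for config.bos_token_id, and the assignment here indexes row i explicitly, which is the per-row behavior the enumerate loop implies):

    # Standalone sketch of the updated non-2D branch of get_position_ids.
    import torch

    def get_position_ids_1d(input_ids, mask_positions, use_gmasks, bos_token_id, device=None):
        batch_size, seq_length = input_ids.shape
        # context ends at the BOS token in each row
        context_lengths = [seq.tolist().index(bos_token_id) for seq in input_ids]
        position_ids = torch.arange(
            seq_length, dtype=torch.long, device=device
        ).unsqueeze(0).repeat(batch_size, 1)
        for i, context_length in enumerate(context_lengths):
            if not use_gmasks[i]:
                # [MASK]-style rows: positions after the context point back at the mask slot
                position_ids[i, context_length:] = mask_positions[i]
        return position_ids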
 
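A quick check of the combined behavior on a mixed batch, reusing the hypothetical ids from the sketches above:

    bos = 130004  # hypothetical stand-in for config.bos_token_id
    input_ids = torch.tensor([
        [10, 11, 130000, bos, 5, 5],   # [MASK] at index 2, context length 3
        [10, 130001, 11, bos, 5, 5],   # [gMASK] at index 1, context length 3
    ])
    pos = get_position_ids_1d(input_ids, mask_positions=[2, 1],
                              use_gmasks=[False, True], bos_token_id=bos)
    # pos -> tensor([[0, 1, 2, 2, 2, 2],    # clamped after the context
    #                [0, 1, 2, 3, 4, 5]])   # gMASK row keeps counting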