zxdu20 committed on
Commit 2200e2b
1 Parent(s): db22499

Add pad_token_id in config.json


Fix position_ids in ChatGLMModel
Add batch position_ids

Files changed (2)
  1. config.json +1 -0
  2. modeling_chatglm.py +31 -30
config.json CHANGED
@@ -10,6 +10,7 @@
   },
   "bos_token_id": 150004,
   "eos_token_id": 150005,
+  "pad_token_id": 20003,
   "hidden_size": 4096,
   "inner_hidden_size": 16384,
   "layernorm_epsilon": 1e-05,
modeling_chatglm.py CHANGED
@@ -850,8 +850,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
                 for i, context_length in enumerate(context_lengths):
                     position_ids[context_length:] = mask_positions[i]

-        position_ids = position_ids.unsqueeze(0)
-
         return position_ids

     @add_start_docstrings_to_model_forward(CHATGLM_6B_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@@ -1007,29 +1005,34 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings

-    def get_masks_and_position_ids(self, seq, mask_position, context_length, device, gmask=False):
-        attention_mask = torch.ones((1, context_length, context_length), device=device)
+    def get_masks_and_position_ids(self, input_ids, mask_positions, device, gmask=False):
+        batch_size, seq_length = input_ids.shape
+        context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
+        attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device)
         attention_mask.tril_()
-        attention_mask[..., :context_length - 1] = 1
+        for i, context_length in enumerate(context_lengths):
+            attention_mask[i, :, :context_length] = 1
         attention_mask.unsqueeze_(1)
         attention_mask = (attention_mask < 0.5).bool()

+        batch_size, seq_length = input_ids.shape
+        context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
         if self.position_encoding_2d:
-            seq_length = seq.index(self.config.bos_token_id)
-            position_ids = torch.arange(context_length, dtype=torch.long, device=device)
+            position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
             if not gmask:
-                position_ids[seq_length:] = mask_position
-            block_position_ids = torch.cat((
-                torch.zeros(seq_length, dtype=torch.long, device=device),
-                torch.arange(context_length - seq_length, dtype=torch.long, device=device) + 1
-            ))
-            position_ids = torch.stack((position_ids, block_position_ids), dim=0)
+                for i, context_length in enumerate(context_lengths):
+                    position_ids[i, context_length:] = mask_positions[i]
+            block_position_ids = [torch.cat((
+                torch.zeros(context_length, dtype=torch.long, device=device),
+                torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
+            )) for context_length in context_lengths]
+            block_position_ids = torch.stack(block_position_ids, dim=0)
+            position_ids = torch.stack((position_ids, block_position_ids), dim=1)
         else:
-            position_ids = torch.arange(context_length, dtype=torch.long, device=device)
+            position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
             if not gmask:
-                position_ids[context_length - 1:] = mask_position
-
-        position_ids = position_ids.unsqueeze(0)
+                for i, context_length in enumerate(context_lengths):
+                    position_ids[context_length:] = mask_positions[i]

         return attention_mask, position_ids

@@ -1041,25 +1044,24 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             attention_mask: Optional[torch.Tensor] = None,
             **kwargs
     ) -> dict:
-
+        batch_size, seq_length = input_ids.shape
         MASK, gMASK = 150000, 150001
         mask_token = MASK if MASK in input_ids else gMASK
        use_gmask = False if MASK in input_ids else gMASK
-        seq = input_ids[0].tolist()
-        mask_position = seq.index(mask_token)
-
-        if mask_token not in seq:
-            raise ValueError("You have to add either [MASK] or [gMASK] in your input")
+        seqs = input_ids.tolist()
+        mask_positions = [seq.index(mask_token) for seq in seqs]

         # only last token for input_ids if past is not None
         if past is not None or past_key_values is not None:
-            context_length = seq.index(self.config.bos_token_id)
+            context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs]
             last_token = input_ids[:, -1].unsqueeze(-1)
             if self.position_encoding_2d:
-                position_ids = torch.tensor([[[mask_position], [len(seq) - context_length]]], dtype=torch.long,
-                                            device=input_ids.device)
+                position_ids = torch.tensor(
+                    [[mask_position, seq_length - context_length] for mask_position, context_length in
+                     zip(mask_positions, context_lengths)], dtype=torch.long, device=input_ids.device).unsqueeze(-1)
             else:
-                position_ids = torch.tensor([[mask_position]], dtype=torch.long, device=input_ids.device)
+                position_ids = torch.tensor([mask_position for mask_position in mask_positions], dtype=torch.long,
+                                            device=input_ids.device).unsqueeze(-1)

             if past is None:
                 past = past_key_values
@@ -1070,9 +1072,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             }
         else:
             attention_mask, position_ids = self.get_masks_and_position_ids(
-                seq=seq,
-                mask_position=mask_position,
-                context_length=len(seq),
+                input_ids,
+                mask_positions=mask_positions,
                 device=input_ids.device,
                 gmask=use_gmask
             )
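To make the new batched logic concrete, here is a minimal standalone sketch that mirrors the get_masks_and_position_ids added above for a toy gMASK-style batch. The token ids are made up except for 150004 (BOS), 150001 (gMASK) and 20003 (pad), which come from this repo's config and code; the real method lives on the model class and also handles the MASK (non-gMASK) case.

import torch

bos_token_id, gmask_token_id = 150004, 150001

# Toy batch, left-padded to a common length (pad id 20003 from config.json).
input_ids = torch.tensor([
    [20003, 11, 12, gmask_token_id, bos_token_id, 30, 31],
    [13, 14, 15, 16, gmask_token_id, bos_token_id, 32],
])
batch_size, seq_length = input_ids.shape
device = input_ids.device

# Per-sample context length = index of BOS, as in the new code.
context_lengths = [seq.tolist().index(bos_token_id) for seq in input_ids]

# Causal mask whose prefix (everything before BOS) is fully visible, per sample.
attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device)
attention_mask.tril_()
for i, context_length in enumerate(context_lengths):
    attention_mask[i, :, :context_length] = 1
attention_mask.unsqueeze_(1)
attention_mask = (attention_mask < 0.5).bool()  # True marks masked positions

# 2D position ids: absolute token positions plus per-sample block positions
# that start counting from 1 at BOS. (gMASK case, so mask positions are not
# written back into the first channel.)
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
block_position_ids = torch.stack([
    torch.cat((
        torch.zeros(context_length, dtype=torch.long, device=device),
        torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1,
    ))
    for context_length in context_lengths
])
position_ids = torch.stack((position_ids, block_position_ids), dim=1)

print(attention_mask.shape)  # torch.Size([2, 1, 7, 7])
print(position_ids.shape)    # torch.Size([2, 2, 7])
print(position_ids[0])       # tensor([[0, 1, 2, 3, 4, 5, 6],
                             #         [0, 0, 0, 0, 1, 2, 3]])

# During incremental decoding (past is not None), prepare_inputs_for_generation
# instead builds per-sample [mask_position, seq_length - context_length] pairs:
mask_positions = [seq.tolist().index(gmask_token_id) for seq in input_ids]
decode_position_ids = torch.tensor(
    [[mask_position, seq_length - context_length]
     for mask_position, context_length in zip(mask_positions, context_lengths)],
    dtype=torch.long, device=device).unsqueeze(-1)
print(decode_position_ids.shape)  # torch.Size([2, 2, 1])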