Add pad_token_id in config.json
Fix position_ids in ChatGLMModel
Add batch position_ids
- config.json +1 -0
- modeling_chatglm.py +31 -30
config.json
CHANGED
@@ -10,6 +10,7 @@
   },
   "bos_token_id": 150004,
   "eos_token_id": 150005,
+  "pad_token_id": 20003,
   "hidden_size": 4096,
   "inner_hidden_size": 16384,
   "layernorm_epsilon": 1e-05,
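With this change, pad_token_id is exposed through the model config alongside the other special-token ids. A minimal sketch of reading it back (the model id and the trust_remote_code usage are assumptions for illustration, not part of the diff):

```python
from transformers import AutoConfig

# Assumes the updated config.json is available under this repo id.
config = AutoConfig.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
print(config.pad_token_id)  # 20003, per the line added above
print(config.bos_token_id)  # 150004
print(config.eos_token_id)  # 150005
```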
modeling_chatglm.py
CHANGED
@@ -850,8 +850,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         for i, context_length in enumerate(context_lengths):
             position_ids[context_length:] = mask_positions[i]
 
-        position_ids = position_ids.unsqueeze(0)
-
         return position_ids
 
     @add_start_docstrings_to_model_forward(CHATGLM_6B_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@@ -1007,29 +1005,34 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings
 
-    def get_masks_and_position_ids(self, seq, mask_position, context_length, device, gmask=False):
-        attention_mask = torch.ones((1, context_length, context_length), device=device)
+    def get_masks_and_position_ids(self, input_ids, mask_positions, device, gmask=False):
+        batch_size, seq_length = input_ids.shape
+        context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
+        attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device)
         attention_mask.tril_()
-        attention_mask[..., :context_length - 1] = 1
+        for i, context_length in enumerate(context_lengths):
+            attention_mask[i, :, :context_length] = 1
         attention_mask.unsqueeze_(1)
         attention_mask = (attention_mask < 0.5).bool()
 
+        batch_size, seq_length = input_ids.shape
+        context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
         if self.position_encoding_2d:
-            seq_length = seq.index(self.config.bos_token_id)
-            position_ids = torch.arange(context_length, dtype=torch.long, device=device)
+            position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
             if not gmask:
-                position_ids[seq_length:] = mask_position
-            block_position_ids = torch.cat((
-                torch.zeros(seq_length, dtype=torch.long, device=device),
-                torch.arange(context_length - seq_length, dtype=torch.long, device=device) + 1
-            ))
-            position_ids = torch.stack((position_ids, block_position_ids), dim=0)
+                for i, context_length in enumerate(context_lengths):
+                    position_ids[i, context_length:] = mask_positions[i]
+            block_position_ids = [torch.cat((
+                torch.zeros(context_length, dtype=torch.long, device=device),
+                torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
+            )) for context_length in context_lengths]
+            block_position_ids = torch.stack(block_position_ids, dim=0)
+            position_ids = torch.stack((position_ids, block_position_ids), dim=1)
         else:
-            position_ids = torch.arange(context_length, dtype=torch.long, device=device)
+            position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
             if not gmask:
-                position_ids[context_length - 1:] = mask_position
-
-        position_ids = position_ids.unsqueeze(0)
+                for i, context_length in enumerate(context_lengths):
+                    position_ids[context_length:] = mask_positions[i]
 
         return attention_mask, position_ids
 
@@ -1041,25 +1044,24 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             attention_mask: Optional[torch.Tensor] = None,
             **kwargs
     ) -> dict:
-
+        batch_size, seq_length = input_ids.shape
         MASK, gMASK = 150000, 150001
         mask_token = MASK if MASK in input_ids else gMASK
         use_gmask = False if MASK in input_ids else gMASK
-        seq = input_ids[0].tolist()
-        mask_position = seq.index(mask_token)
-
-        if mask_token not in seq:
-            raise ValueError("You have to add either [MASK] or [gMASK] in your input")
+        seqs = input_ids.tolist()
+        mask_positions = [seq.index(mask_token) for seq in seqs]
 
         # only last token for input_ids if past is not None
         if past is not None or past_key_values is not None:
-            context_length = seq.index(self.config.bos_token_id)
+            context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs]
             last_token = input_ids[:, -1].unsqueeze(-1)
             if self.position_encoding_2d:
-                position_ids = torch.tensor([[[mask_position], [len(seq) - context_length]]], dtype=torch.long,
-                                            device=input_ids.device)
+                position_ids = torch.tensor(
+                    [[mask_position, seq_length - context_length] for mask_position, context_length in
+                     zip(mask_positions, context_lengths)], dtype=torch.long, device=input_ids.device).unsqueeze(-1)
             else:
-                position_ids = torch.tensor([[mask_position]], dtype=torch.long, device=input_ids.device)
+                position_ids = torch.tensor([mask_position for mask_position in mask_positions], dtype=torch.long,
+                                            device=input_ids.device).unsqueeze(-1)
 
             if past is None:
                 past = past_key_values
@@ -1070,9 +1072,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             }
         else:
            attention_mask, position_ids = self.get_masks_and_position_ids(
-                seq=seq,
-                mask_position=mask_position,
-                context_length=len(seq),
+                input_ids,
+                mask_positions=mask_positions,
                 device=input_ids.device,
                 gmask=use_gmask
             )
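As a sanity check on the shapes produced by the rewritten get_masks_and_position_ids, here is a self-contained sketch that mirrors its position_encoding_2d path for a gMASK-style prompt on hard-coded toy token ids; the ids, sequence lengths, and printed shapes below are illustrative assumptions, not part of this commit:

```python
import torch

bos_token_id, gmask_token_id = 150004, 150001  # from config.json and the MASK/gMASK constants above

# Two already-tokenized prompts of equal length; real inputs would come from the tokenizer.
input_ids = torch.tensor([
    [11, 12, gmask_token_id, bos_token_id, 21, 22],  # context length 3
    [31, gmask_token_id, bos_token_id, 41, 42, 43],  # context length 2
])
batch_size, seq_length = input_ids.shape
context_lengths = [seq.tolist().index(bos_token_id) for seq in input_ids]

# Attention mask: bidirectional over each prefix, causal afterwards (as in the new code).
attention_mask = torch.ones((batch_size, seq_length, seq_length)).tril_()
for i, context_length in enumerate(context_lengths):
    attention_mask[i, :, :context_length] = 1
attention_mask = (attention_mask.unsqueeze(1) < 0.5)

# 2D positions: absolute positions plus block positions counted from <bos>, per sample.
position_ids = torch.arange(seq_length, dtype=torch.long).expand(batch_size, seq_length)
block_position_ids = torch.stack([torch.cat((
    torch.zeros(context_length, dtype=torch.long),
    torch.arange(seq_length - context_length, dtype=torch.long) + 1
)) for context_length in context_lengths])
position_ids = torch.stack((position_ids, block_position_ids), dim=1)

print(attention_mask.shape)  # torch.Size([2, 1, 6, 6])
print(position_ids.shape)    # torch.Size([2, 2, 6])
```

The key difference from the previous single-sequence version is that every per-sample quantity (context length, mask position, block positions) is now computed per row of the batch, so padded batches of different prompts can be generated in one call.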