duzx16 committed
Commit: f81daa3
Parent(s): 7bcdc71

Fix batch generation for vision model

Files changed: modeling_chatglm.py (+20 -5)

modeling_chatglm.py CHANGED
@@ -692,16 +692,16 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
         """Initialize the weights."""
         return
 
-    def get_masks(self, input_ids, past_key_values, padding_mask=None):
-        batch_size, seq_length = input_ids.shape
-        full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
+    def get_masks(self, input_embeds, past_key_values, padding_mask=None):
+        batch_size, seq_length, embed_size = input_embeds.shape
+        full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_embeds.device)
         full_attention_mask.tril_()
         past_length = 0
         if past_key_values:
             past_length = past_key_values[0][0].shape[2]
         if past_length:
             full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length,
-                                                         device=input_ids.device), full_attention_mask), dim=-1)
+                                                         device=input_embeds.device), full_attention_mask), dim=-1)
         if padding_mask is not None:
             full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
         if not past_length and padding_mask is not None:
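Note: the hunk above sizes the attention mask from the embedding tensor instead of input_ids, so the mask follows the sequence the transformer actually sees. A minimal standalone sketch of the same mask construction, using a hypothetical helper name and toy shapes (not the repository's API):

import torch

def build_full_attention_mask(input_embeds, past_length=0, padding_mask=None):
    # Lower-triangular (causal) mask whose length comes from the embeddings.
    batch_size, seq_length, _ = input_embeds.shape
    mask = torch.ones(batch_size, seq_length, seq_length, device=input_embeds.device)
    mask.tril_()
    if past_length:
        # Every new query position may attend to all cached key positions.
        past = torch.ones(batch_size, seq_length, past_length, device=input_embeds.device)
        mask = torch.cat((past, mask), dim=-1)
    if padding_mask is not None:
        # padding_mask is (batch, past_length + seq_length); zero out padded keys.
        mask = mask * padding_mask.unsqueeze(1)
    return mask  # (batch, seq_length, past_length + seq_length), 1 = may attend

# Toy usage: batch of 2, 4 positions, embedding width 8, no KV cache.
embeds = torch.zeros(2, 4, 8)
pad = torch.tensor([[0., 1., 1., 1.], [1., 1., 1., 1.]])
print(build_full_attention_mask(embeds, padding_mask=pad).shape)  # torch.Size([2, 4, 4])

The hunk ends mid-function at line 707, so the sketch only covers the portion shown and returns the raw float mask.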
@@ -887,7 +887,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
 
         if full_attention_mask is None:
             if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
-                full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
+                full_attention_mask = self.get_masks(inputs_embeds, past_key_values, padding_mask=attention_mask)
 
         # Rotary positional embeddings
         rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
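Why the call site matters for the vision model: the image features are spliced into the embedding sequence before the transformer runs, so inputs_embeds is longer than input_ids, and a mask sized from input_ids would not cover the image positions. Rough arithmetic with hypothetical sizes (not the model's real config values):

# Hypothetical values, for illustration only.
image_size, patch_size = 224, 14
num_patches = (image_size // patch_size // 2) ** 2   # 224//14 = 16 per side, halved to 8, squared -> 64

text_len = 10                            # length of input_ids (with one image placeholder slot)
embed_len = text_len - 1 + num_patches   # the placeholder slot is replaced by num_patches features
print(text_len, embed_len)               # 10 73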
@@ -976,6 +976,21 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         # only last token for input_ids if past is not None
         if position_ids is None:
             position_ids = self.get_position_ids(input_ids, device=input_ids.device)
+        if attention_mask is not None:
+            image_size: int = self.config.vision_config['image_size']
+            patch_size: int = self.config.vision_config['patch_size']
+            num_patches = (image_size // patch_size // 2) ** 2
+            new_attention_masks = []
+            for i in range(len(input_ids)):
+                input_id = input_ids[i].tolist()
+                boi_token_pos, eoi_token_pos = input_id.index(self.config.boi_token_id), input_id.index(
+                    self.config.eoi_token_id)
+                assert eoi_token_pos - boi_token_pos == 2
+                new_attention_masks.append(torch.cat(
+                    (attention_mask[i, :boi_token_pos + 1], attention_mask.new_ones(num_patches),
+                     attention_mask[i, eoi_token_pos:])
+                ))
+            attention_mask = torch.stack(new_attention_masks, dim=0)
         if not is_first_forward:
             if past_key_values is not None:
                 position_ids = position_ids[..., -1:]
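The added block is the batch-generation fix itself: when generation runs on a padded batch, the user-supplied attention_mask still has text-token length, so it is rebuilt to insert num_patches ones between the BOI and EOI markers, matching the expanded embedding sequence. A standalone sketch of the same expansion with hypothetical token ids and a hypothetical patch count (the real values come from the model config):

import torch

BOI, EOI = 101, 102   # hypothetical marker ids (config.boi_token_id / config.eoi_token_id)
num_patches = 4       # hypothetical (image_size // patch_size // 2) ** 2

input_ids = torch.tensor([
    [5, BOI, 3, EOI, 7, 8],   # sample 0: no padding, 3 is the image placeholder slot
    [0, 5, BOI, 3, EOI, 7],   # sample 1: left-padded with one pad token
])
attention_mask = torch.tensor([
    [1, 1, 1, 1, 1, 1],
    [0, 1, 1, 1, 1, 1],
])

new_masks = []
for ids, mask in zip(input_ids.tolist(), attention_mask):
    boi, eoi = ids.index(BOI), ids.index(EOI)
    assert eoi - boi == 2             # exactly one placeholder between the markers
    new_masks.append(torch.cat((
        mask[:boi + 1],               # everything up to and including BOI
        mask.new_ones(num_patches),   # image positions are always attended
        mask[eoi:],                   # EOI and the remaining text
    )))
attention_mask = torch.stack(new_masks, dim=0)
print(attention_mask.shape)  # torch.Size([2, 9]): 6 - 1 + 4 positions per sample

Because every sample gains the same num_patches positions, the per-sample masks stay equal in length and torch.stack can rebuild the batched mask, which is what allows batched generation to run.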