duzx16 commited on
Commit
f81daa3
1 Parent(s): 7bcdc71

Fix batch generation for vision model

Browse files
Files changed (1) hide show
  1. modeling_chatglm.py +20 -5
modeling_chatglm.py CHANGED
@@ -692,16 +692,16 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
692
  """Initialize the weights."""
693
  return
694
 
695
- def get_masks(self, input_ids, past_key_values, padding_mask=None):
696
- batch_size, seq_length = input_ids.shape
697
- full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
698
  full_attention_mask.tril_()
699
  past_length = 0
700
  if past_key_values:
701
  past_length = past_key_values[0][0].shape[2]
702
  if past_length:
703
  full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length,
704
- device=input_ids.device), full_attention_mask), dim=-1)
705
  if padding_mask is not None:
706
  full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
707
  if not past_length and padding_mask is not None:
@@ -887,7 +887,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
887
 
888
  if full_attention_mask is None:
889
  if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
890
- full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
891
 
892
  # Rotary positional embeddings
893
  rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
@@ -976,6 +976,21 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
976
  # only last token for input_ids if past is not None
977
  if position_ids is None:
978
  position_ids = self.get_position_ids(input_ids, device=input_ids.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
979
  if not is_first_forward:
980
  if past_key_values is not None:
981
  position_ids = position_ids[..., -1:]
 
692
  """Initialize the weights."""
693
  return
694
 
695
+ def get_masks(self, input_embeds, past_key_values, padding_mask=None):
696
+ batch_size, seq_length, embed_size = input_embeds.shape
697
+ full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_embeds.device)
698
  full_attention_mask.tril_()
699
  past_length = 0
700
  if past_key_values:
701
  past_length = past_key_values[0][0].shape[2]
702
  if past_length:
703
  full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length,
704
+ device=input_embeds.device), full_attention_mask), dim=-1)
705
  if padding_mask is not None:
706
  full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
707
  if not past_length and padding_mask is not None:
 
887
 
888
  if full_attention_mask is None:
889
  if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
890
+ full_attention_mask = self.get_masks(inputs_embeds, past_key_values, padding_mask=attention_mask)
891
 
892
  # Rotary positional embeddings
893
  rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
 
976
  # only last token for input_ids if past is not None
977
  if position_ids is None:
978
  position_ids = self.get_position_ids(input_ids, device=input_ids.device)
979
+ if attention_mask is not None:
980
+ image_size: int = self.config.vision_config['image_size']
981
+ patch_size: int = self.config.vision_config['patch_size']
982
+ num_patches = (image_size // patch_size // 2) ** 2
983
+ new_attention_masks = []
984
+ for i in range(len(input_ids)):
985
+ input_id = input_ids[i].tolist()
986
+ boi_token_pos, eoi_token_pos = input_id.index(self.config.boi_token_id), input_id.index(
987
+ self.config.eoi_token_id)
988
+ assert eoi_token_pos - boi_token_pos == 2
989
+ new_attention_masks.append(torch.cat(
990
+ (attention_mask[i, :boi_token_pos + 1], attention_mask.new_ones(num_patches),
991
+ attention_mask[i, eoi_token_pos:])
992
+ ))
993
+ attention_mask = torch.stack(new_attention_masks, dim=0)
994
  if not is_first_forward:
995
  if past_key_values is not None:
996
  position_ids = position_ids[..., -1:]