Fix: AttributeError when `input_ids` is None during multimodal LLM training

#77
by lyulumos - opened
Files changed (1) hide show
  1. modeling_chatglm.py +5 -4
modeling_chatglm.py CHANGED
@@ -771,15 +771,16 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
771
  if padding_mask is not None and not padding_mask.all():
772
  return padding_mask
773
  return None
774
- batch_size, seq_length = input_ids.shape
775
- full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
 
776
  full_attention_mask.tril_()
777
  past_length = 0
778
  if past_key_values:
779
  past_length = past_key_values[0][0].shape[2]
780
  if past_length:
781
  full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length,
782
- device=input_ids.device), full_attention_mask), dim=-1)
783
  if padding_mask is not None:
784
  full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
785
  if not past_length and padding_mask is not None:
@@ -872,7 +873,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
872
  use_cache = use_cache if use_cache is not None else self.config.use_cache
873
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
874
 
875
- batch_size, seq_length = input_ids.shape
876
 
877
  if inputs_embeds is None:
878
  inputs_embeds = self.embedding(input_ids)
 
771
  if padding_mask is not None and not padding_mask.all():
772
  return padding_mask
773
  return None
774
+ batch_size, seq_length = input_ids.shape if input_ids is not None else padding_mask.shape
775
+ device = input_ids.device if input_ids is not None else padding_mask.device
776
+ full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=device)
777
  full_attention_mask.tril_()
778
  past_length = 0
779
  if past_key_values:
780
  past_length = past_key_values[0][0].shape[2]
781
  if past_length:
782
  full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length,
783
+ device=device), full_attention_mask), dim=-1)
784
  if padding_mask is not None:
785
  full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
786
  if not past_length and padding_mask is not None:
 
873
  use_cache = use_cache if use_cache is not None else self.config.use_cache
874
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
875
 
876
+ batch_size, seq_length = (input_ids.shape if input_ids is not None else inputs_embeds.shape[:2] if inputs_embeds is not None else (None, None))
877
 
878
  if inputs_embeds is None:
879
  inputs_embeds = self.embedding(input_ids)