Fix: AttributeError when `input_ids` is None during multimodal LLM training
#77
by
lyulumos
- opened
- modeling_chatglm.py +5 -4
modeling_chatglm.py
CHANGED
@@ -771,15 +771,16 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
|
|
771 |
if padding_mask is not None and not padding_mask.all():
|
772 |
return padding_mask
|
773 |
return None
|
774 |
-
batch_size, seq_length = input_ids.shape
|
775 |
-
|
|
|
776 |
full_attention_mask.tril_()
|
777 |
past_length = 0
|
778 |
if past_key_values:
|
779 |
past_length = past_key_values[0][0].shape[2]
|
780 |
if past_length:
|
781 |
full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length,
|
782 |
-
device=
|
783 |
if padding_mask is not None:
|
784 |
full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
|
785 |
if not past_length and padding_mask is not None:
|
@@ -872,7 +873,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|
872 |
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
873 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
874 |
|
875 |
-
batch_size, seq_length = input_ids.shape
|
876 |
|
877 |
if inputs_embeds is None:
|
878 |
inputs_embeds = self.embedding(input_ids)
|
|
|
771 |
if padding_mask is not None and not padding_mask.all():
|
772 |
return padding_mask
|
773 |
return None
|
774 |
+
batch_size, seq_length = input_ids.shape if input_ids is not None else padding_mask.shape
|
775 |
+
device = input_ids.device if input_ids is not None else padding_mask.device
|
776 |
+
full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=device)
|
777 |
full_attention_mask.tril_()
|
778 |
past_length = 0
|
779 |
if past_key_values:
|
780 |
past_length = past_key_values[0][0].shape[2]
|
781 |
if past_length:
|
782 |
full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length,
|
783 |
+
device=device), full_attention_mask), dim=-1)
|
784 |
if padding_mask is not None:
|
785 |
full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
|
786 |
if not past_length and padding_mask is not None:
|
|
|
873 |
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
874 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
875 |
|
876 |
+
batch_size, seq_length = (input_ids.shape if input_ids is not None else inputs_embeds.shape[:2] if inputs_embeds is not None else (None, None))
|
877 |
|
878 |
if inputs_embeds is None:
|
879 |
inputs_embeds = self.embedding(input_ids)
|