zxdu20 commited on
Commit
35ca523
1 Parent(s): 0829959

Fix input embeds

Browse files
Files changed (1) hide show
  1. modeling_chatglm.py +2 -3
modeling_chatglm.py CHANGED
@@ -918,7 +918,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
918
  elif input_ids is not None:
919
  batch_size, seq_length = input_ids.shape[:2]
920
  elif inputs_embeds is not None:
921
- batch_size, seq_length, _ = inputs_embeds.shape[:2]
922
  else:
923
  raise ValueError("You have to specify either input_ids or inputs_embeds")
924
 
@@ -972,9 +972,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
972
 
973
  if attention_mask is None:
974
  attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()
975
-
976
  else:
977
- attention_mask = attention_mask.to(input_ids.device)
978
 
979
  for i, layer in enumerate(self.layers):
980
 
918
  elif input_ids is not None:
919
  batch_size, seq_length = input_ids.shape[:2]
920
  elif inputs_embeds is not None:
921
+ batch_size, seq_length = inputs_embeds.shape[:2]
922
  else:
923
  raise ValueError("You have to specify either input_ids or inputs_embeds")
924
 
972
 
973
  if attention_mask is None:
974
  attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()
 
975
  else:
976
+ attention_mask = attention_mask.to(hidden_states.device)
977
 
978
  for i, layer in enumerate(self.layers):
979