zxdu20 committed
Commit 08bc851
1 Parent(s): 4b7ffbf

Fix attention mask for prefix prompt

Files changed (1): modeling_chatglm.py (+6 -5)
--- a/modeling_chatglm.py
+++ b/modeling_chatglm.py
@@ -919,11 +919,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
                 device=input_ids.device
             )
 
-        if self.pre_seq_len is not None:
-            prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to(
-                attention_mask.device)
-            prefix_attention_mask = (prefix_attention_mask < 0.5).bool()
-            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)
 
         if position_ids is None:
             MASK, gMASK = 150000, 150001
@@ -938,6 +933,12 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
                 gmask=use_gmask
             )
 
+        if self.pre_seq_len is not None and attention_mask is not None:
+            prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to(
+                attention_mask.device)
+            prefix_attention_mask = (prefix_attention_mask < 0.5).bool()
+            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)
+
         # [seq_len, batch, hidden_size]
         hidden_states = inputs_embeds.transpose(0, 1)
 
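In short, the commit moves the prefix-prompt mask handling from before the position-id setup to after it and adds an `attention_mask is not None` guard, so the prefix columns are only prepended when a mask actually exists. Below is a minimal, standalone sketch of the concatenation the new guard protects; the shapes are illustrative assumptions (the real mask is built by the model's own helpers, which are not shown in this diff).

import torch

# Minimal sketch of the guarded prefix-mask concatenation from this commit.
# Shapes here are assumptions for illustration: attention_mask is a boolean
# [batch, 1, query_len, key_len] tensor where True means "masked out".
batch_size, seq_len, pre_seq_len = 2, 4, 3

attention_mask = torch.ones(batch_size, 1, seq_len, seq_len, dtype=torch.bool)

if pre_seq_len is not None and attention_mask is not None:
    # All-ones -> (x < 0.5) is all False, so the prefix keys are never masked.
    prefix_attention_mask = torch.ones(batch_size, 1, seq_len, pre_seq_len).to(attention_mask.device)
    prefix_attention_mask = (prefix_attention_mask < 0.5).bool()
    # Prepend the prefix columns along the key (last) dimension.
    attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)

print(attention_mask.shape)  # torch.Size([2, 1, 4, 7]): pre_seq_len + seq_len keys

Because the prefix tensor is all ones, the `< 0.5` comparison yields all-False entries, i.e. the prefix positions remain visible to every query token; only the key dimension of the mask grows by `pre_seq_len`.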