zxdu20 committed
Commit 0564795
1 Parent(s): 2200e2b
Files changed (1)
  1. modeling_chatglm.py (+1, -2)
modeling_chatglm.py CHANGED
@@ -817,7 +817,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         # past_key_values = [(v1,v2) for v1, v2 in zip(past_key_values[0], past_key_values[1])]
         return past_key_values
 
-    @staticmethod
     def get_masks(self, input_ids, device):
         batch_size, seq_length = input_ids.shape
         context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
@@ -900,7 +899,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         )
 
         if self.pre_seq_len is not None:
-            prefix_attention_mask = torch.ones(1, 1, input_ids.size(-1), self.pre_seq_len).to(attention_mask.device)
+            prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to(attention_mask.device)
             prefix_attention_mask = (prefix_attention_mask < 0.5).bool()
             attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)
 
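Context on the two hunks (not part of the commit itself): get_masks reads self.config.bos_token_id, so it is an instance method and the stray @staticmethod decorator is dropped. The second hunk matters because torch.cat along dim=3 requires every other dimension to match exactly, so a prefix mask built with a leading dimension of 1 cannot be concatenated with the per-sample attention mask once the batch holds more than one sequence. A minimal standalone sketch of the shape issue, with arbitrary illustrative sizes (not taken from the model):

import torch

batch_size, seq_len, pre_seq_len = 2, 4, 3

# Boolean attention mask with the per-sample shape produced by get_masks:
# (batch_size, 1, seq_len, seq_len).
attention_mask = torch.zeros(batch_size, 1, seq_len, seq_len, dtype=torch.bool)

# Old code: leading dimension hard-coded to 1.
old_prefix = (torch.ones(1, 1, seq_len, pre_seq_len) < 0.5).bool()
try:
    torch.cat((old_prefix, attention_mask), dim=3)
except RuntimeError as err:
    print("batch_size > 1 fails:", err)  # size mismatch on dim 0 (1 vs. 2)

# Fixed code: the prefix mask carries the real batch dimension.
new_prefix = (torch.ones(batch_size, 1, seq_len, pre_seq_len) < 0.5).bool()
merged = torch.cat((new_prefix, attention_mask), dim=3)
print(merged.shape)  # torch.Size([2, 1, 4, 7])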