duzx16 commited on
Commit
c57e892
1 Parent(s): fc442f7

Fix prefix prompt in evaluation

Browse files
Files changed (1) hide show
  1. modeling_chatglm.py +8 -5
modeling_chatglm.py CHANGED
@@ -803,6 +803,14 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
803
  if inputs_embeds is None:
804
  inputs_embeds = self.embedding(input_ids)
805
 
 
 
 
 
 
 
 
 
806
  if full_attention_mask is None:
807
  if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
808
  full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
@@ -815,11 +823,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
815
  rotary_pos_emb = rotary_pos_emb[None, :seq_length]
816
  rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
817
 
818
- if past_key_values is None:
819
- if self.pre_seq_len is not None:
820
- past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device,
821
- dtype=inputs_embeds.dtype)
822
-
823
  # Run encoder.
824
  hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
825
  inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
 
803
  if inputs_embeds is None:
804
  inputs_embeds = self.embedding(input_ids)
805
 
806
+ if self.pre_seq_len is not None:
807
+ if past_key_values is None:
808
+ past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device,
809
+ dtype=inputs_embeds.dtype)
810
+ if attention_mask is not None:
811
+ attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)),
812
+ attention_mask], dim=-1)
813
+
814
  if full_attention_mask is None:
815
  if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
816
  full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
 
823
  rotary_pos_emb = rotary_pos_emb[None, :seq_length]
824
  rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
825
 
 
 
 
 
 
826
  # Run encoder.
827
  hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
828
  inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,