zxdu20 commited on
Commit
774fda8
1 Parent(s): 6ed40fb

Fix batch beam search

Browse files
Files changed (1) hide show
  1. modeling_glm.py +10 -0
modeling_glm.py CHANGED
@@ -873,6 +873,16 @@ class GLMForConditionalGeneration(GLMPreTrainedModel):
873
  position_ids = position_ids[:, :, :seq_length]
874
  if attention_mask is not None:
875
  attention_mask = attention_mask[:, :, :seq_length, :seq_length]
 
 
 
 
 
 
 
 
 
 
876
  return {
877
  "input_ids": input_ids,
878
  "position_ids": position_ids,
 
873
  position_ids = position_ids[:, :, :seq_length]
874
  if attention_mask is not None:
875
  attention_mask = attention_mask[:, :, :seq_length, :seq_length]
876
+ if position_ids is not None and input_ids.size(0) > position_ids.size(0):
877
+ batch_size = position_ids.size(0)
878
+ num_beams = input_ids.size(0) // batch_size
879
+ position_ids = position_ids.unsqueeze(1).expand(-1, num_beams, -1, -1)
880
+ position_ids = position_ids.reshape(batch_size * num_beams, *position_ids.shape[-2:])
881
+ if attention_mask is not None and input_ids.size(0) > attention_mask.size(0):
882
+ batch_size = attention_mask.size(0)
883
+ num_beams = input_ids.size(0) // batch_size
884
+ attention_mask = attention_mask.unsqueeze(1).expand(-1, num_beams, -1, -1, -1)
885
+ attention_mask = attention_mask.reshape(batch_size * num_beams, *attention_mask.shape[-3:])
886
  return {
887
  "input_ids": input_ids,
888
  "position_ids": position_ids,