fix left padding batch infer

#9
Files changed (1) hide show
  1. modeling_baichuan.py +7 -4
modeling_baichuan.py CHANGED
@@ -358,10 +358,13 @@ class BaichuanModel(BaichuanPreTrainedModel):
358
  expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
359
  inputs_embeds.device
360
  )
361
- combined_attention_mask = (
362
- expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
363
- )
364
-
 
 
 
365
  return combined_attention_mask
366
 
367
  def forward(
 
358
  expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
359
  inputs_embeds.device
360
  )
361
+ if combined_attention_mask is None:
362
+ combined_attention_mask = expanded_attn_mask
363
+ else:
364
+ expanded_attn_mask = torch.where(expanded_attn_mask == torch.finfo(inputs_embeds.dtype).min, torch.finfo(inputs_embeds.dtype).min / 2, expanded_attn_mask)
365
+ combined_attention_mask = torch.where(combined_attention_mask == torch.finfo(inputs_embeds.dtype).min, torch.finfo(inputs_embeds.dtype).min / 2, expanded_attn_mask)
366
+ combined_attention_mask = expanded_attn_mask + combined_attention_mask
367
+
368
  return combined_attention_mask
369
 
370
  def forward(