cArlIcon commited on
Commit
7666d66
1 Parent(s): 0748797

update modeling_yi.py

Browse files
Files changed (1) hide show
  1. modeling_yi.py +1 -1
modeling_yi.py CHANGED
@@ -539,7 +539,7 @@ class YiModel(YiPreTrainedModel):
539
  def _prepare_decoder_attention_mask(
540
  self, attention_mask, input_ids, inputs_embeds, past_key_values_length
541
  ):
542
- input_shape = input_ids.shape
543
  # create causal mask
544
  # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
545
  combined_attention_mask = None
 
539
  def _prepare_decoder_attention_mask(
540
  self, attention_mask, input_ids, inputs_embeds, past_key_values_length
541
  ):
542
+ input_shape = input_ids.shape if input_ids else inputs_embeds.shape[:-1]
543
  # create causal mask
544
  # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
545
  combined_attention_mask = None