davda54 commited on
Commit
502858a
1 Parent(s): 900d7b0
Files changed (1) hide show
  1. modeling_nort5.py +1 -1
modeling_nort5.py CHANGED
@@ -405,7 +405,7 @@ class NorT5Model(NorT5PreTrainedModel):
405
  def get_decoder_output(
406
  self, target_ids, encoder_output, attention_mask
407
  ):
408
- batch_size, seq_length = target_ids.shape
409
  device = target_ids.device
410
 
411
  if attention_mask is None:
 
405
  def get_decoder_output(
406
  self, target_ids, encoder_output, attention_mask
407
  ):
408
+ batch_size, seq_length, _ = encoder_output.shape
409
  device = target_ids.device
410
 
411
  if attention_mask is None: