prompteus geekweilai commited on
Commit
53d0b79
1 Parent(s): e660e78

Update model_class.py (#5)

Browse files

- Update model_class.py (c0a6c8d4ea3905485ea20f4f16f6a87c09538893)


Co-authored-by: weilai <geekweilai@users.noreply.huggingface.co>

Files changed (1) hide show
  1. model_class.py +2 -0
model_class.py CHANGED
@@ -26,6 +26,7 @@ class WhisperForAudioCaptioning(transformers.WhisperForConditionalGeneration):
26
  output_hidden_states: Optional[bool] = None,
27
  return_dict: Optional[bool] = None,
28
  forced_ac_decoder_ids: Optional[torch.LongTensor] = None, # added to be ignored when passed from trainer
 
29
  ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
30
  return super().forward(
31
  input_features=input_features,
@@ -43,6 +44,7 @@ class WhisperForAudioCaptioning(transformers.WhisperForConditionalGeneration):
43
  output_attentions=output_attentions,
44
  output_hidden_states=output_hidden_states,
45
  return_dict=return_dict,
 
46
  )
47
 
48
  # copy-pasted and adapted from transformers.WhisperForConditionalGeneration.generate
 
26
  output_hidden_states: Optional[bool] = None,
27
  return_dict: Optional[bool] = None,
28
  forced_ac_decoder_ids: Optional[torch.LongTensor] = None, # added to be ignored when passed from trainer
29
+ decoder_position_ids: Optional[torch.LongTensor] = None,
30
  ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
31
  return super().forward(
32
  input_features=input_features,
 
44
  output_attentions=output_attentions,
45
  output_hidden_states=output_hidden_states,
46
  return_dict=return_dict,
47
+ decoder_position_ids=decoder_position_ids,
48
  )
49
 
50
  # copy-pasted and adapted from transformers.WhisperForConditionalGeneration.generate