prompteus commited on
Commit
a75f70c
1 Parent(s): e8e231a

Update model_class.py

Browse files
Files changed (1) hide show
  1. model_class.py +3 -0
model_class.py CHANGED
@@ -26,6 +26,7 @@ class WhisperForAudioCaptioning(transformers.WhisperForConditionalGeneration):
26
  output_hidden_states: Optional[bool] = None,
27
  return_dict: Optional[bool] = None,
28
  forced_ac_decoder_ids: Optional[torch.LongTensor] = None, # added to be ignored when passed from trainer
 
29
  ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
30
  return super().forward(
31
  input_features=input_features,
@@ -43,6 +44,7 @@ class WhisperForAudioCaptioning(transformers.WhisperForConditionalGeneration):
43
  output_attentions=output_attentions,
44
  output_hidden_states=output_hidden_states,
45
  return_dict=return_dict,
 
46
  )
47
 
48
  # copy-pasted and adapted from transformers.WhisperForConditionalGeneration.generate
@@ -156,3 +158,4 @@ class WhisperForAudioCaptioning(transformers.WhisperForConditionalGeneration):
156
  decoder_input_ids=decoder_input_ids,
157
  **kwargs,
158
  )
 
 
26
  output_hidden_states: Optional[bool] = None,
27
  return_dict: Optional[bool] = None,
28
  forced_ac_decoder_ids: Optional[torch.LongTensor] = None, # added to be ignored when passed from trainer
29
+ decoder_position_ids: Optional[torch.LongTensor] = None,
30
  ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
31
  return super().forward(
32
  input_features=input_features,
 
44
  output_attentions=output_attentions,
45
  output_hidden_states=output_hidden_states,
46
  return_dict=return_dict,
47
+ decoder_position_ids=decoder_position_ids,
48
  )
49
 
50
  # copy-pasted and adapted from transformers.WhisperForConditionalGeneration.generate
 
158
  decoder_input_ids=decoder_input_ids,
159
  **kwargs,
160
  )
161
+