import transformers
import torch
from typing import Optional, Tuple, Union

from transformers.modeling_outputs import Seq2SeqLMOutput
from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor
from transformers.models.whisper.tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE

class WhisperForAudioCaptioning(transformers.WhisperForConditionalGeneration):
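    """Whisper fine-tuned for audio captioning.

    Compared to the plain `WhisperForConditionalGeneration`, this class
    accepts an extra `forced_ac_decoder_ids` tensor: caption-prefix tokens
    that `generate` places right after Whisper's forced special-token prefix
    (language, task, and optional timestamp markers) before decoding starts.

    A minimal usage sketch (the checkpoint path and tensor names are
    placeholders, not part of this class):

        model = WhisperForAudioCaptioning.from_pretrained("path/to/checkpoint")
        generated = model.generate(
            inputs=input_features,  # log-mel features from a Whisper feature extractor
            forced_ac_decoder_ids=prefix_ids,  # shape (batch_size, prefix_len)
            task="transcribe",
            language="english",
        )
    """
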
    def forward(
        self,
        input_features: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        forced_ac_decoder_ids: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
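        # `forced_ac_decoder_ids` is accepted but deliberately dropped here:
        # it only matters at generation time (see `generate` below). Accepting
        # it keeps `forward` compatible with training batches that carry the
        # field, presumably produced by a caption-aware data collator.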
        return super().forward(
            input_features=input_features,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            encoder_outputs=encoder_outputs,
            past_key_values=past_key_values,
            decoder_inputs_embeds=decoder_inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        forced_ac_decoder_ids: Optional[torch.Tensor] = None,
        generation_config=None,
        logits_processor=None,
        stopping_criteria=None,
        prefix_allowed_tokens_fn=None,
        synced_gpus=False,
        return_timestamps=None,
        task="transcribe",
        language="english",
        **kwargs,
    ):
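        # This largely mirrors `WhisperForConditionalGeneration.generate` from
        # the transformers version this code was written against; the one
        # addition is that `forced_ac_decoder_ids` gets appended after the
        # forced special-token prefix and passed in as `decoder_input_ids`.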
        if generation_config is None:
            generation_config = self.generation_config

        if return_timestamps is not None:
            if not hasattr(generation_config, "no_timestamps_token_id"):
                raise ValueError(
                    "You are trying to return timestamps, but the generation config is not properly set. "
                    "Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. "
                    "For more details on how to generate the appropriate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363"
                )

            generation_config.return_timestamps = return_timestamps
        else:
            generation_config.return_timestamps = False

        if language is not None:
            generation_config.language = language
        if task is not None:
            generation_config.task = task

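        # Build Whisper's forced special-token prefix: position 1 holds the
        # language token, position 2 the task token, optionally followed by
        # the no-timestamps token.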
        forced_decoder_ids = []
        if task is not None or language is not None:
            if hasattr(generation_config, "language"):
                if generation_config.language in generation_config.lang_to_id.keys():
                    language_token = generation_config.language
                elif generation_config.language in TO_LANGUAGE_CODE.keys():
                    language_token = f"<|{TO_LANGUAGE_CODE[generation_config.language]}|>"
                else:
                    raise ValueError(
                        f"Unsupported language: {generation_config.language}. Language should be one of:"
                        f" {list(TO_LANGUAGE_CODE.keys())}."
                    )
                forced_decoder_ids.append((1, generation_config.lang_to_id[language_token]))
            else:
                # No language configured: leave position 1 unforced so the
                # model can detect the language itself.
                forced_decoder_ids.append((1, None))

            if hasattr(generation_config, "task"):
                if generation_config.task in TASK_IDS:
                    forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task]))
                else:
                    raise ValueError(
                        f"The `{generation_config.task}` task is not supported. The task should be one of `{TASK_IDS}`."
                    )
            else:
                forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"]))

            if hasattr(generation_config, "no_timestamps_token_id") and not generation_config.return_timestamps:
                idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
                forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))
        elif hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids is not None:
            forced_decoder_ids = self.config.forced_decoder_ids
        elif (
            hasattr(self.generation_config, "forced_decoder_ids")
            and self.generation_config.forced_decoder_ids is not None
        ):
            forced_decoder_ids = self.generation_config.forced_decoder_ids

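        # When timestamps are requested, this replaces any caller-provided
        # `logits_processor`, same as the upstream Whisper `generate` this
        # method is adapted from.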
        if generation_config.return_timestamps:
            logits_processor = [WhisperTimeStampLogitsProcessor(generation_config)]

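        # Turn the forced (position, token_id) pairs into an explicit decoder
        # prompt and append the caption prefix, so decoding continues from
        # `<|startoftranscript|> <lang> <task> [<|notimestamps|>] <caption prefix...>`.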
        decoder_input_ids = None

        if len(forced_decoder_ids) > 0:
            # Make the forced ids a contiguous prefix starting at position 0.
            forced_decoder_ids.sort()
            if min(forced_decoder_ids)[0] != 0:
                forced_decoder_ids = [(0, self.config.decoder_start_token_id)] + forced_decoder_ids

            position_indices, decoder_input_ids = zip(*forced_decoder_ids)
            assert tuple(position_indices) == tuple(range(len(position_indices))), (
                "forced_decoder_ids is not a (contiguous) prefix, we can't handle that"
            )

            device = self.get_decoder().device

            if forced_ac_decoder_ids is None:
                # Default to an empty caption prefix for a batch of one.
                forced_ac_decoder_ids = torch.tensor([[]], device=device, dtype=torch.long)

            # Replicate the special-token prefix for every sample in the batch
            # and append the per-sample caption prefix.
            batch_size = forced_ac_decoder_ids.shape[0]
            fluff_len = len(decoder_input_ids)
            decoder_input_ids = torch.tensor(decoder_input_ids, device=device, dtype=torch.long)
            decoder_input_ids = decoder_input_ids.expand((batch_size, fluff_len))
            decoder_input_ids = torch.cat([decoder_input_ids, forced_ac_decoder_ids], dim=1)

            generation_config.forced_decoder_ids = forced_decoder_ids

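        # Call `super()` past `WhisperPreTrainedModel` in the MRO, i.e. the
        # generic `GenerationMixin.generate`, bypassing Whisper's own
        # `generate` so the prebuilt `decoder_input_ids` prompt is used as-is.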
        return super(transformers.WhisperPreTrainedModel, self).generate(
            inputs,
            generation_config,
            logits_processor,
            stopping_criteria,
            prefix_allowed_tokens_fn,
            synced_gpus,
            decoder_input_ids=decoder_input_ids,
            **kwargs,
        )