|
from typing import Optional, Union, Tuple, List |
|
|
|
import torch |
|
from transformers import VisionEncoderDecoderModel |
|
from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput |
|
|
|
|
|
class OrderVisionEncoderDecoderModel(VisionEncoderDecoderModel):
    """Vision encoder-decoder whose decoder is driven by bounding boxes.

    Overrides ``VisionEncoderDecoderModel.forward`` so that the decoder receives
    ``input_boxes`` / ``input_boxes_mask`` / ``input_boxes_counts`` (taken from the
    ``decoder_input_boxes*`` arguments) and ``labels`` directly; any loss in the
    result is whatever the decoder reports as ``loss``.
    """

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        decoder_input_boxes: Optional[torch.LongTensor] = None,
        decoder_input_boxes_mask: Optional[torch.LongTensor] = None,
        decoder_input_boxes_counts: Optional[torch.LongTensor] = None,
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[List[List[int]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        """Encode the image (unless ``encoder_outputs`` is given) and run the box decoder.

        Args:
            pixel_values: Image tensor passed to ``self.encoder``. Required when
                ``encoder_outputs`` is None.
            decoder_input_boxes: Forwarded to the decoder as ``input_boxes``.
                NOTE(review): presumably (batch, num_boxes, 4) box coordinates —
                confirm against the decoder implementation.
            decoder_input_boxes_mask: Forwarded to the decoder as ``input_boxes_mask``.
            decoder_input_boxes_counts: Forwarded to the decoder as ``input_boxes_counts``.
            encoder_outputs: Precomputed encoder output, either a ``ModelOutput`` or a
                plain tuple (which is wrapped in ``BaseModelOutput``). Element 0 is used
                as the encoder hidden states.
            past_key_values: Decoder cache, forwarded unchanged.
            decoder_inputs_embeds: Forwarded to the decoder as ``inputs_embeds``.
            labels: Forwarded to the decoder as ``labels``.
            use_cache: Forwarded to the decoder unchanged.
            output_attentions: Forwarded to both encoder and decoder.
            output_hidden_states: Forwarded to both encoder and decoder.
            return_dict: Return a ``Seq2SeqLMOutput`` when truthy; defaults to
                ``self.config.use_return_dict``.
            **kwargs: Keys starting with ``decoder_`` are passed to the decoder with
                that prefix stripped; all other keys are passed to the encoder.

        Returns:
            A ``Seq2SeqLMOutput`` when ``return_dict`` is truthy, otherwise the
            decoder's output tuple concatenated with the encoder's output tuple.

        Raises:
            ValueError: If both ``pixel_values`` and ``encoder_outputs`` are None.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Route extra kwargs: "decoder_"-prefixed ones go to the decoder (prefix
        # stripped), everything else goes to the encoder.
        decoder_prefix = "decoder_"
        kwargs_encoder = {
            argument: value for argument, value in kwargs.items() if not argument.startswith(decoder_prefix)
        }
        kwargs_decoder = {
            argument[len(decoder_prefix):]: value
            for argument, value in kwargs.items()
            if argument.startswith(decoder_prefix)
        }

        if encoder_outputs is None:
            if pixel_values is None:
                raise ValueError("You have to specify pixel_values")

            encoder_outputs = self.encoder(
                pixel_values=pixel_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs_encoder,
            )
        elif isinstance(encoder_outputs, tuple):
            # Normalise a caller-supplied tuple so attribute access below works.
            encoder_outputs = BaseModelOutput(*encoder_outputs)

        encoder_hidden_states = encoder_outputs[0]

        # Project encoder states to the decoder width only when the two hidden sizes
        # differ and the decoder config does not set cross_attention_hidden_size.
        if (
            self.encoder.config.hidden_size != self.decoder.config.hidden_size
            and self.decoder.config.cross_attention_hidden_size is None
        ):
            encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)

        decoder_outputs = self.decoder(
            input_boxes=decoder_input_boxes,
            input_boxes_mask=decoder_input_boxes_mask,
            input_boxes_counts=decoder_input_boxes_counts,
            encoder_hidden_states=encoder_hidden_states,
            # No encoder attention mask is built here; the decoder receives None.
            encoder_attention_mask=None,
            inputs_embeds=decoder_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
            past_key_values=past_key_values,
            return_dict=return_dict,
            labels=labels,
            **kwargs_decoder,
        )

        if not return_dict:
            # encoder_outputs is a ModelOutput (not a tuple) when it was passed in as
            # one, or when a tuple was wrapped in BaseModelOutput above; tuple +
            # ModelOutput raises TypeError, so convert it first.
            if not isinstance(encoder_outputs, tuple):
                encoder_outputs = encoder_outputs.to_tuple()
            return decoder_outputs + encoder_outputs

        return Seq2SeqLMOutput(
            loss=decoder_outputs.loss,
            logits=decoder_outputs.logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
|
|