from typing import List, Optional, Tuple, Union

import torch
from transformers import VisionEncoderDecoderModel
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput


class OrderVisionEncoderDecoderModel(VisionEncoderDecoderModel):
    """Vision encoder-decoder whose decoder additionally receives per-image boxes.

    Behaves like ``VisionEncoderDecoderModel`` except that ``forward`` takes
    ``decoder_input_boxes`` / ``decoder_input_boxes_mask`` /
    ``decoder_input_boxes_counts`` and passes them to the decoder as
    ``input_boxes`` / ``input_boxes_mask`` / ``input_boxes_counts``.
    """

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        # Shape (batch_size, num_boxes, 4), all coords scaled 0 - 1000, with 1001 as padding
        decoder_input_boxes: Optional[torch.LongTensor] = None,
        # Shape (batch_size, num_boxes), 0 if padding, 1 otherwise
        decoder_input_boxes_mask: Optional[torch.LongTensor] = None,
        # Shape (batch_size), number of boxes in each image
        decoder_input_boxes_counts: Optional[torch.LongTensor] = None,
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[List[List[int]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        """Encode ``pixel_values`` (unless ``encoder_outputs`` is given) and run the decoder.

        Args:
            pixel_values: Image input for the encoder. Required when
                ``encoder_outputs`` is not supplied.
            decoder_input_boxes: Boxes forwarded to the decoder as ``input_boxes``.
            decoder_input_boxes_mask: Forwarded to the decoder as ``input_boxes_mask``.
            decoder_input_boxes_counts: Forwarded to the decoder as ``input_boxes_counts``.
            encoder_outputs: Precomputed encoder outputs; a plain tuple is wrapped
                into ``BaseModelOutput``.
            past_key_values: Decoder cache, forwarded unchanged.
            decoder_inputs_embeds: Forwarded to the decoder as ``inputs_embeds``.
            labels: Forwarded unchanged to the decoder, which computes the loss.
            use_cache, output_attentions, output_hidden_states: Forwarded to the
                encoder and/or decoder.
            return_dict: Whether to return a ``Seq2SeqLMOutput``; defaults to
                ``self.config.use_return_dict``.
            **kwargs: Keys starting with ``decoder_`` go to the decoder with that
                prefix stripped; all other keys go to the encoder.

        Returns:
            ``Seq2SeqLMOutput`` when ``return_dict`` is true, otherwise the decoder
            output tuple followed by the encoder output tuple.

        Raises:
            ValueError: If neither ``pixel_values`` nor ``encoder_outputs`` is given.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Route extra kwargs by prefix: "decoder_*" -> decoder (prefix stripped), rest -> encoder.
        kwargs_encoder = {
            argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")
        }
        kwargs_decoder = {
            argument[len("decoder_") :]: value
            for argument, value in kwargs.items()
            if argument.startswith("decoder_")
        }

        if encoder_outputs is None:
            if pixel_values is None:
                raise ValueError("You have to specify pixel_values")

            encoder_outputs = self.encoder(
                pixel_values=pixel_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs_encoder,
            )
        elif isinstance(encoder_outputs, tuple):
            encoder_outputs = BaseModelOutput(*encoder_outputs)

        encoder_hidden_states = encoder_outputs[0]

        # Project encoder states to the decoder width only when the sizes differ and
        # the decoder does not declare its own cross-attention input size.
        if (
            self.encoder.config.hidden_size != self.decoder.config.hidden_size
            and self.decoder.config.cross_attention_hidden_size is None
        ):
            encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)

        # No cross-attention mask is built: the decoder is given None for every call.
        encoder_attention_mask = None

        # Decode
        decoder_outputs = self.decoder(
            input_boxes=decoder_input_boxes,
            input_boxes_mask=decoder_input_boxes_mask,
            input_boxes_counts=decoder_input_boxes_counts,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
            past_key_values=past_key_values,
            return_dict=return_dict,
            labels=labels,
            **kwargs_decoder,
        )

        if not return_dict:
            # ModelOutput is an OrderedDict, not a tuple, so `tuple + ModelOutput`
            # raises TypeError. That happens whenever the caller supplied
            # `encoder_outputs` (a tuple is wrapped into BaseModelOutput above);
            # flatten it back to a tuple before concatenating.
            # NOTE(review): assumes the decoder returns a plain tuple when
            # return_dict=False — confirm against the decoder implementation.
            if not isinstance(encoder_outputs, tuple):
                encoder_outputs = encoder_outputs.to_tuple()
            return decoder_outputs + encoder_outputs

        return Seq2SeqLMOutput(
            loss=decoder_outputs.loss,
            logits=decoder_outputs.logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )