#############################
#   Imports
#############################

# Python modules
from typing import (
    Optional,
    Tuple,
    Union,
    List,
)

# Remote modules
import torch
from torch import nn
from transformers import (
    BartConfig,
    BartPretrainedModel,
)
from transformers.modeling_outputs import (
    BaseModelOutput,
    Seq2SeqModelOutput,
)
from transformers.models.bart.modeling_bart import shift_tokens_right
from transformers.utils import (
    add_code_sample_docstrings,
    add_end_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)

# Local modules
from .config import BartCustomConfig
from .encoder import BartCustomEncoder
from .decoder import BartCustomDecoder
from .custom_constants import BartConstants
from .custom_outputs import CustomSeq2SeqModelOutput


# Encoder-decoder BART variant: one token-embedding table (`self.shared`) is
# handed to both the custom encoder and the custom decoder, and `forward`
# accepts an extra `relation_inputs` argument that is passed to the encoder
# only.
#
# NOTE: the description lives in comments rather than a class docstring
# because `add_start_docstrings` composes the class `__doc__`; adding a
# docstring here would change the generated documentation.
@add_start_docstrings(
    "The bare BART Model outputting raw hidden-states without any specific head on top.",
    BartConstants.BART_START_DOCSTRING,
)
class BartCustomModel(BartPretrainedModel):
    def __init__(self, config: BartCustomConfig):
        """Build the shared embedding, the encoder and the decoder from `config`."""
        super().__init__(config)

        padding_idx, vocab_size = config.pad_token_id, config.vocab_size
        # A single embedding module is passed to both the encoder and the decoder.
        self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)

        self.encoder = BartCustomEncoder(config, self.shared)
        self.decoder = BartCustomDecoder(config, self.shared)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the embedding module shared by the encoder and the decoder."""
        return self.shared

    def set_input_embeddings(self, value):
        """Replace the shared embedding and re-point encoder and decoder at it."""
        self.shared = value
        self.encoder.embed_tokens = self.shared
        self.decoder.embed_tokens = self.shared

    def get_encoder(self):
        """Return the encoder sub-module."""
        return self.encoder

    def get_decoder(self):
        """Return the decoder sub-module."""
        return self.decoder

    # Run the encoder (unless `encoder_outputs` is supplied) and then the
    # decoder.
    #
    # * If neither `decoder_input_ids` nor `decoder_inputs_embeds` is given,
    #   the decoder inputs are derived from `input_ids` with
    #   `shift_tokens_right`; a `ValueError` is raised when `input_ids` is
    #   also missing.
    # * `output_attentions`, `output_hidden_states`, `use_cache` and
    #   `return_dict` fall back to the values stored on `self.config` when
    #   left as `None`.
    # * `relation_inputs` is forwarded to the encoder only.
    # * Returns a `CustomSeq2SeqModelOutput` when `return_dict` is truthy,
    #   otherwise the decoder outputs followed by the encoder outputs as one
    #   flat tuple (which does not include `head_mask`).
    #
    # Documentation is kept in comments because the decorators below build
    # this method's `__doc__`.
    # NOTE(review): the code sample is documented with `Seq2SeqModelOutput`
    # although the method returns `CustomSeq2SeqModelOutput` - confirm that
    # this is intended.
    @add_start_docstrings_to_model_forward(BartConstants.BART_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        processor_class=BartConstants.TOKENIZER_FOR_DOC,
        checkpoint=BartConstants.CHECKPOINT_FOR_DOC,
        output_type=Seq2SeqModelOutput,
        config_class=BartConstants.CONFIG_FOR_DOC,
        expected_output=BartConstants.EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        relation_inputs: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, CustomSeq2SeqModelOutput]:
        # different to other models, Bart automatically creates decoder_input_ids from
        # input_ids if no decoder_input_ids are provided
        if decoder_input_ids is None and decoder_inputs_embeds is None:
            if input_ids is None:
                raise ValueError(
                    "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
                    "passed, `input_ids` cannot be `None`. Please pass either "
                    "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
                )

            decoder_input_ids = shift_tokens_right(
                input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
            )

        # Per-call flags take precedence; `None` means "use the config value".
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                relation_inputs=relation_inputs,
            )
        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            # Caller-supplied `encoder_outputs` may be a list (as annotated) or
            # a BaseModelOutput; neither can be concatenated to a tuple
            # directly, so normalize to a plain tuple first. A tuple input is
            # left unchanged.
            if isinstance(encoder_outputs, BaseModelOutput):
                encoder_outputs = encoder_outputs.to_tuple()
            # Assumes the decoder returns a plain tuple when return_dict is
            # falsy, as the original concatenation did.
            return decoder_outputs + tuple(encoder_outputs)

        return CustomSeq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
            # The encoder head mask supplied by the caller is echoed back unchanged.
            encoder_head_mask=head_mask,
        )