# Copyright 2022 san kim # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import copy import math import os import warnings from dataclasses import dataclass from typing import Optional, Tuple, Union, Callable import torch from torch import nn from torch.nn import CrossEntropyLoss from torch.utils.checkpoint import checkpoint try: from torch.nn import Identity except ImportError: # Older PyTorch compatibility class Identity(nn.Module): r"""A placeholder identity operator that is argument-insensitive.""" def __init__(self, *args, **kwargs): super().__init__() def forward(self, input): return input from transformers.models.t5.modeling_t5 import ( T5LayerSelfAttention, T5LayerCrossAttention, T5LayerFF, T5PreTrainedModel, T5LayerNorm, PARALLELIZE_DOCSTRING, DEPARALLELIZE_DOCSTRING, __HEAD_MASK_WARNING_MSG, T5_START_DOCSTRING, T5_INPUTS_DOCSTRING ) from transformers.modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput, BaseModelOutput ) from transformers.utils import ( DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings, add_start_docstrings_to_model_forward, is_torch_fx_proxy, logging, replace_return_docstrings, ModelOutput, ) from transformers.utils.model_parallel_utils import assert_device_map, get_device_map from transformers import T5Config from transformers.configuration_utils import PretrainedConfig from transformers.activations import get_activation logger = logging.get_logger(__name__) _CONFIG_FOR_DOC_DDT5 = "T5Config" def get_last_token_index(mask): 
def get_last_token_index(mask):
    """Return, for every row of ``mask``, the position with the largest ``index * mask`` value.

    Each column index is scaled by the mask entry at that position and the
    arg-max over the sequence axis is taken. For a 0/1 attention mask this is
    the index of the last position whose mask entry is 1.

    Args:
        mask: attention mask laid out as ``[batch_size, seq]``.

    Returns:
        Tensor with one position per row of ``mask``.
    """
    _, num_positions = mask.shape[:2]
    positions = torch.arange(num_positions, device=mask.device, requires_grad=False)
    # No reduction here: weighted[j, i] = positions[i] * mask[j, i].
    weighted = torch.einsum("i,ji->ji", positions, mask)
    return weighted.argmax(dim=1)
# modified from huggingface transformers lib (add attn pooling)
class SequenceSummary(nn.Module):
    r"""
    Compute a single vector summary of a sequence hidden states.

    Args:
        config ([`PretrainedConfig`]):
            The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
            config class of your model for the default values it uses):

            - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:

                - `"last"` -- Take the last token hidden state (like XLNet)
                - `"first"` -- Take the first token hidden state (like Bert)
                - `"mean"` -- Take the mean of all tokens hidden states
                - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
                - `"attn"` -- Pool the sequence with `num_queries` learned query vectors through multi-head
                  attention followed by a LayerNorm (reads `config.hidden_size`, `config.num_attention_heads` and,
                  if present, `config.layer_norm_eps`)

            - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
            - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
              (otherwise to `config.hidden_size`).
            - **summary_activation** (`Optional[str]`) -- Name of the activation applied after the projection (looked
              up with `get_activation`); an empty string or `None` adds no activation.
            - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
            - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation.
        num_queries (`int`, *optional*, defaults to 1):
            Number of learned query vectors; only used when `summary_type == "attn"`.
    """

    def __init__(self, config: PretrainedConfig, num_queries=1):
        super().__init__()

        self.summary_type = getattr(config, "summary_type", "last")
        if self.summary_type == "attn":
            # Learned queries attend over the sequence; one pooled vector is produced per query.
            self.queries = nn.Parameter(torch.empty(num_queries, config.hidden_size))
            nn.init.kaiming_uniform_(self.queries, a=math.sqrt(5))
            self.MultiheadAttention = nn.MultiheadAttention(
                config.hidden_size, config.num_attention_heads, batch_first=True
            )
            layer_norm_eps = getattr(config, "layer_norm_eps", 1e-6)
            self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=layer_norm_eps)

        # Optional projection, either to the label space or back to hidden_size.
        self.summary = Identity()
        if hasattr(config, "summary_use_proj") and config.summary_use_proj:
            if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
                num_classes = config.num_labels
            else:
                num_classes = config.hidden_size
            self.summary = nn.Linear(config.hidden_size, num_classes)

        activation_string = getattr(config, "summary_activation", None)
        self.activation: Callable = get_activation(activation_string) if activation_string else Identity()

        self.first_dropout = Identity()
        if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
            self.first_dropout = nn.Dropout(config.summary_first_dropout)

        self.last_dropout = Identity()
        if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
            self.last_dropout = nn.Dropout(config.summary_last_dropout)

    def forward(
        self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
    ) -> torch.FloatTensor:
        """
        Compute a single vector summary of a sequence hidden states.

        Args:
            hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
                The hidden states of the last layer.
            cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
                Used if `summary_type == "cls_index"`; when omitted, the last token of the sequence is taken as the
                classification token.

        Returns:
            `torch.FloatTensor`: The summary of the sequence hidden states. With `summary_type == "attn"` the
            result keeps a query axis, i.e. one pooled vector per learned query.

        Raises:
            ValueError: If `summary_type` is not one of the supported values.
        """
        if self.summary_type == "last":
            output = hidden_states[:, -1]
        elif self.summary_type == "first":
            output = hidden_states[:, 0]
        elif self.summary_type == "mean":
            output = hidden_states.mean(dim=1)
        elif self.summary_type == "cls_index":
            if cls_index is None:
                # Default to the last position of the sequence axis.
                cls_index = torch.full_like(
                    hidden_states[..., :1, :],
                    hidden_states.shape[-2] - 1,
                    dtype=torch.long,
                )
            else:
                cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
                cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
            # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
            output = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, XX, hidden_size)
        elif self.summary_type == "attn":
            batch_size = hidden_states.size(0)
            queries = self.queries.repeat(batch_size, 1, 1)
            # NOTE(review): no key padding mask is passed, so padded positions take part in the pooling.
            output = self.MultiheadAttention(queries, hidden_states, hidden_states, need_weights=False)[0]
            output = self.LayerNorm(output)
        else:
            # Previously an unknown type fell through and crashed below with an UnboundLocalError.
            raise ValueError(
                f"Unsupported summary_type {self.summary_type!r}; expected one of "
                "'last', 'first', 'mean', 'cls_index', 'attn'."
            )

        output = self.first_dropout(output)
        output = self.summary(output)
        output = self.activation(output)
        output = self.last_dropout(output)

        return output
# add_cross_attention
class T5DecoderBlock(nn.Module):
    """T5 block whose cross-attention sub-layer is optional.

    The block always holds a self-attention layer and a feed-forward layer; a
    cross-attention layer is inserted between them only when
    ``config.add_cross_attention`` is truthy.

    Args:
        config: model configuration; ``is_decoder`` and ``add_cross_attention``
            are read here and the object is forwarded to the T5 sub-layers.
        has_relative_attention_bias: forwarded to the self-attention layer.
    """

    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
        self.is_decoder = config.is_decoder
        self.has_cross_attention = config.add_cross_attention
        self.layer = nn.ModuleList()
        self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
        if self.has_cross_attention:
            self.layer.append(T5LayerCrossAttention(config))

        self.layer.append(T5LayerFF(config))

    @staticmethod
    def _clamp_fp16_overflow(hidden_states):
        """Clamp inf values so that fp16 training can continue; other tensors pass through unchanged."""
        if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
        return hidden_states

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
        layer_head_mask=None,
        cross_attn_layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        return_dict=True,
    ):
        """Run self-attention, optional cross-attention and the feed-forward layer.

        Cross-attention runs only when the block was built with it AND
        ``encoder_hidden_states`` is given. ``past_key_value`` must then hold 4
        tensors (self-attention key/value followed by cross-attention
        key/value); otherwise it must hold 2. ``return_dict`` is accepted but
        not used by this method.

        Returns:
            Tuple ``(hidden_states, [present_key_value_state], *attention_outputs)``.
            The key/value entry is present only when ``use_cache`` is truthy;
            ``attention_outputs`` are whatever the attention sub-layers return
            after their first two elements (self-attention first, then
            cross-attention).

        Raises:
            ValueError: If ``past_key_value`` has an unexpected number of entries.
        """
        # Decide this first: the expected cache layout depends on whether
        # cross-attention actually runs, not merely on encoder states being passed.
        do_cross_attention = self.has_cross_attention and encoder_hidden_states is not None

        if past_key_value is not None:
            if not self.is_decoder:
                logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.")
            expected_num_past_key_values = 4 if do_cross_attention else 2

            if len(past_key_value) != expected_num_past_key_values:
                raise ValueError(
                    f"There should be {expected_num_past_key_values} past states. "
                    f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
                    f"Got {len(past_key_value)} past key / value states"
                )

            self_attn_past_key_value = past_key_value[:2]
            cross_attn_past_key_value = past_key_value[2:]
        else:
            self_attn_past_key_value, cross_attn_past_key_value = None, None

        self_attention_outputs = self.layer[0](
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=self_attn_past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        hidden_states, present_key_value_state = self_attention_outputs[:2]
        attention_outputs = self_attention_outputs[2:]  # Keep self-attention outputs and relative position weights

        hidden_states = self._clamp_fp16_overflow(hidden_states)

        if do_cross_attention:
            # the actual query length is unknown for cross attention
            # if using past key value states. Need to inject it here
            if present_key_value_state is not None:
                query_length = present_key_value_state[0].shape[2]
            else:
                query_length = None

            cross_attention_outputs = self.layer[1](
                hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                position_bias=encoder_decoder_position_bias,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=cross_attn_past_key_value,
                query_length=query_length,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )
            hidden_states = self._clamp_fp16_overflow(cross_attention_outputs[0])

            # Combine self attn and cross attn key value states
            if present_key_value_state is not None:
                present_key_value_state = present_key_value_state + cross_attention_outputs[1]

            # Keep cross-attention outputs and relative position weights
            attention_outputs = attention_outputs + cross_attention_outputs[2:]

        # Apply Feed Forward layer
        hidden_states = self._clamp_fp16_overflow(self.layer[-1](hidden_states))

        outputs = (hidden_states,)

        if use_cache:
            outputs = outputs + (present_key_value_state,) + attention_outputs
        else:
            outputs = outputs + attention_outputs

        # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights),
        # (cross-attention position bias), (cross-attention weights)
        return outputs
@dataclass
class BaseModelOutputWithPastAndCrossAttentionsAndPositionBias(ModelOutput):
    """
    Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding) plus
    position bias.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.

            If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
            hidden_size)` is output.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
            `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
            encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
            `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
            input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
            weighted average in the cross-attention heads.
        position_bias (`torch.FloatTensor`, *optional*, returned when the model is self-attention decoder):
            position_bias is created in the first layer of the self-attention decoder, and it passes through all the
            layers including layers of the cross-attention decoder. `torch.FloatTensor` of shape `(batch_size,
            num_heads, sequence_length, sequence_length)`.
    """

    # Every field defaults to None, so an output can be built with any subset of them.
    last_hidden_state: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    # Extra field compared with the stock "past and cross attentions" output: the self-attention position bias.
    position_bias: Optional[torch.FloatTensor] = None
class T5DecoderStack(T5PreTrainedModel):
    """Stack of `T5DecoderBlock`s followed by a final layer norm and dropout.

    Compared with a plain T5 stack, `forward` accepts an externally computed
    `position_bias` / `encoder_decoder_position_bias` and always hands the
    self-attention position bias back to the caller, so that it can be fed
    into another stack.

    Args:
        config: model configuration (`is_decoder`, `add_cross_attention`,
            `num_layers`, `d_model`, `layer_norm_epsilon` and `dropout_rate`
            are read here).
        embed_tokens: optional embedding module used when `input_ids` are given.
        has_relative_attention_bias: when truthy, only the first block gets a
            relative attention bias; when falsy, no block does.
    """

    def __init__(self, config, embed_tokens=None, has_relative_attention_bias=True):
        super().__init__(config)

        self.embed_tokens = embed_tokens
        self.is_decoder = config.is_decoder
        self.has_cross_attention = config.add_cross_attention

        self.block = nn.ModuleList(
            [
                T5DecoderBlock(config, has_relative_attention_bias=bool(i == 0) and has_relative_attention_bias)
                for i in range(config.num_layers)
            ]
        )
        self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

        # Initialize weights and apply final processing
        self.post_init()
        # Model parallel
        self.model_parallel = False
        self.device_map = None
        self.gradient_checkpointing = False

    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
        # Check validity of device_map
        self.device_map = (
            get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map
        )
        assert_device_map(self.device_map, len(self.block))
        self.model_parallel = True
        self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
        self.last_device = "cuda:" + str(max(self.device_map.keys()))
        # Load onto devices
        for k, v in self.device_map.items():
            for layer in v:
                cuda_device = "cuda:" + str(k)
                self.block[layer] = self.block[layer].to(cuda_device)

        # Set embed_tokens to first layer
        self.embed_tokens = self.embed_tokens.to(self.first_device) if self.embed_tokens is not None else self.embed_tokens
        # Set final layer norm to last device
        self.final_layer_norm = self.final_layer_norm.to(self.last_device)

    # Fixed: this method previously carried PARALLELIZE_DOCSTRING, i.e. the documentation of `parallelize`.
    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    def deparallelize(self):
        self.model_parallel = False
        self.device_map = None
        self.first_device = "cpu"
        self.last_device = "cpu"
        for i in range(len(self.block)):
            self.block[i] = self.block[i].to("cpu")
        self.embed_tokens = self.embed_tokens.to("cpu") if self.embed_tokens is not None else self.embed_tokens
        self.final_layer_norm = self.final_layer_norm.to("cpu")
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, new_embeddings):
        self.embed_tokens = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        inputs_embeds=None,
        position_bias=None,
        encoder_decoder_position_bias=None,
        head_mask=None,
        cross_attn_head_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run all blocks and return the final hidden states plus the position bias.

        Exactly one of `input_ids` / `inputs_embeds` must be given. `use_cache`,
        `output_attentions`, `output_hidden_states` and `return_dict` fall back
        to the corresponding `config` values when left as `None`.

        Returns:
            `BaseModelOutputWithPastAndCrossAttentionsAndPositionBias`, or, when
            `return_dict` is falsy, a tuple of the non-`None` entries among
            (hidden states, key/value cache, all hidden states, self-attentions,
            cross-attentions) followed by `position_bias` (always appended,
            even if it is `None`).

        Raises:
            ValueError: If both or neither of `input_ids` and `inputs_embeds`
                are given.
        """
        # Model parallel
        if self.model_parallel:
            torch.cuda.set_device(self.first_device)
            self.embed_tokens = self.embed_tokens.to(self.first_device) if self.embed_tokens is not None else self.embed_tokens
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(
                f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")

        if inputs_embeds is None:
            assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
            inputs_embeds = self.embed_tokens(input_ids)

        batch_size, seq_length = input_shape

        # required mask seq length can be calculated via length of past
        mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length

        if use_cache is True:
            assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder"

        if attention_mask is None:
            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
        if self.has_cross_attention and encoder_attention_mask is None and encoder_hidden_states is not None:
            encoder_seq_length = encoder_hidden_states.shape[1]
            encoder_attention_mask = torch.ones(
                batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long
            )

        # initialize past_key_values with `None` if past does not exist
        if past_key_values is None:
            past_key_values = [None] * len(self.block)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.has_cross_attention and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        head_mask = self.get_head_mask(head_mask, self.config.num_layers)
        cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
        present_key_value_states = () if use_cache else None
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and self.has_cross_attention) else None

        hidden_states = self.dropout(inputs_embeds)

        for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
            layer_head_mask = head_mask[i]
            cross_attn_layer_head_mask = cross_attn_head_mask[i]
            # Model parallel
            if self.model_parallel:
                torch.cuda.set_device(hidden_states.device)
                # Ensure that attention_mask is always on the same device as hidden_states
                if attention_mask is not None:
                    attention_mask = attention_mask.to(hidden_states.device)
                if position_bias is not None:
                    position_bias = position_bias.to(hidden_states.device)
                if encoder_hidden_states is not None:
                    encoder_hidden_states = encoder_hidden_states.to(hidden_states.device)
                if encoder_extended_attention_mask is not None:
                    encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
                if encoder_decoder_position_bias is not None:
                    encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
                if layer_head_mask is not None:
                    layer_head_mask = layer_head_mask.to(hidden_states.device)
                if cross_attn_layer_head_mask is not None:
                    cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                if use_cache:
                    logger.warning(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # Positional order must match T5DecoderBlock.forward up to `output_attentions`.
                        return tuple(module(*inputs, use_cache, output_attentions))

                    return custom_forward

                layer_outputs = checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    extended_attention_mask,
                    position_bias,
                    encoder_hidden_states,
                    encoder_extended_attention_mask,
                    encoder_decoder_position_bias,
                    layer_head_mask,
                    cross_attn_layer_head_mask,
                    None,  # past_key_value is always None with gradient checkpointing
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask=extended_attention_mask,
                    position_bias=position_bias,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_extended_attention_mask,
                    encoder_decoder_position_bias=encoder_decoder_position_bias,
                    layer_head_mask=layer_head_mask,
                    cross_attn_layer_head_mask=cross_attn_layer_head_mask,
                    past_key_value=past_key_value,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            # layer_outputs is a tuple with:
            # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
            # The block omits the key/value slot whenever `use_cache` is falsy (not only when it
            # is exactly `False`), so pad on the same condition to keep the indices below aligned.
            if not use_cache:
                layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]

            hidden_states, present_key_value_state = layer_outputs[:2]

            # We share the position biases between the layers - the first layer store them
            # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
            # (cross-attention position bias), (cross-attention weights)
            position_bias = layer_outputs[2]
            if self.has_cross_attention and encoder_hidden_states is not None:
                encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
            # append next layer key value states
            if use_cache:
                present_key_value_states = present_key_value_states + (present_key_value_state,)

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[3],)
                # Cross-attention weights exist only if cross-attention actually ran; without
                # encoder states the block returns a shorter tuple and index 5 would be out of range.
                if self.has_cross_attention and encoder_hidden_states is not None:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[5],)

            # Model Parallel: If it's the last layer for that device, put things on the next device
            if self.model_parallel:
                for k, v in self.device_map.items():
                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
                        hidden_states = hidden_states.to("cuda:" + str(k + 1))

        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            outputs = tuple(
                v
                for v in [
                    hidden_states,
                    present_key_value_states,
                    all_hidden_states,
                    all_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
            # position_bias is appended unconditionally, so it is always the last element.
            outputs = outputs + (position_bias,)
            return outputs
        return BaseModelOutputWithPastAndCrossAttentionsAndPositionBias(
            last_hidden_state=hidden_states,
            past_key_values=present_key_value_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
            position_bias=position_bias
        )
@dataclass
class DualDecoderModelOutput(ModelOutput):
    """
    Base class for the dual-decoder model's outputs, which also contain pre-computed hidden states that can speed up
    sequential decoding.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the decoder of the model.

            If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
            hidden_size)` is output.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
        cross_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the cross-attention decoder at the output of each layer.
        cross_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights of the cross-attention decoder, after the attention softmax, used to compute the
            weighted average in the cross-attention heads.
        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
            weighted average in the cross-attention heads.
        self_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the self-attention decoder of the model.
        self_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the self-attention decoder at the output of each layer plus the optional initial embedding
            outputs.
        self_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights of the self-attention decoder, after the attention softmax, used to compute the weighted
            average in the self-attention heads.
    """

    # Outputs attributed to the cross-attention decoder.
    last_hidden_state: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    cross_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    cross_decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    # Outputs attributed to the self-attention decoder.
    self_decoder_last_hidden_state: Optional[torch.FloatTensor] = None
    self_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    self_decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class DualDecoderLMOutput(ModelOutput):
    """
    Base class for sequence-to-sequence language models outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
        cross_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the cross-attention decoder at the output of each layer.
        cross_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights of the cross-attention decoder, after the attention softmax, used to compute the
            weighted average in the cross-attention heads.
        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
            weighted average in the cross-attention heads.
        self_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the self-attention decoder of the model.
        self_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the self-attention decoder at the output of each layer plus the optional initial embedding
            outputs.
        self_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights of the self-attention decoder, after the attention softmax, used to compute the weighted
            average in the self-attention heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    # NOTE(review): annotated with three levels of Tuple nesting while the docstring above describes
    # `tuple(tuple(torch.FloatTensor))` — confirm which layout the producing model actually returns.
    past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
    cross_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    cross_decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    self_decoder_last_hidden_state: Optional[torch.FloatTensor] = None
    self_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    self_decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
cross_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the cross-attention decoder at the output of each layer. cross_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights of the cross-attention decoder, after the attention softmax, used to compute the weighted average in the cross-attention heads. cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the weighted average in the cross-attention heads. self_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Sequence of hidden-states at the output of the last layer of the self-attention decoder of the model. self_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the self-attention decoder at the output of each layer plus the optional initial embedding outputs. 
self_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights of the self-attention decoder, after the attention softmax, used to compute the weighted average in the self-attention heads. """ loss: Optional[torch.FloatTensor] = None logits: torch.FloatTensor = None ss_logits: torch.FloatTensor = None past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None cross_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None cross_decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None cross_attentions: Optional[Tuple[torch.FloatTensor]] = None self_decoder_last_hidden_state: Optional[torch.FloatTensor] = None self_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None self_decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None @add_start_docstrings("""T5 Dual Decoder with a `language modeling` head on top.""", T5_START_DOCSTRING) class T5DualDecoderLMHeadModel(T5PreTrainedModel): def __init__(self, config: T5Config, add_pooling_layer: bool = True): config.is_encoder_decoder = False config.is_decoder = True super().__init__(config) self.model_dim = config.d_model self.shared = nn.Embedding(config.vocab_size, config.d_model) self_decoder_config = copy.deepcopy(config) self_decoder_config.is_decoder = True self_decoder_config.is_encoder_decoder = False self_decoder_config.add_cross_attention = False # self.self_decoder = T5DecoderStack(self_decoder_config, self.shared) self.encoder = T5DecoderStack(self_decoder_config, self.shared) cross_decoder_config = copy.deepcopy(config) cross_decoder_config.is_decoder = True cross_decoder_config.is_encoder_decoder = False cross_decoder_config.add_cross_attention = True cross_decoder_config.num_layers = config.num_decoder_layers # self.cross_decoder = 
T5DecoderStack(cross_decoder_config, has_relative_attention_bias=False) self.decoder = T5DecoderStack(cross_decoder_config, self.shared, has_relative_attention_bias=False) self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() # Model parallel self.model_parallel = False self.device_map = None @add_start_docstrings(PARALLELIZE_DOCSTRING) def parallelize(self, device_map=None): self.device_map = ( # get_device_map(len(self.self_decoder.block), range(torch.cuda.device_count())) get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) if device_map is None else device_map ) # assert_device_map(self.device_map, len(self.self_decoder.block)) assert_device_map(self.device_map, len(self.encoder.block)) # self.self_decoder.parallelize(self.device_map) # self.cross_decoder.parallelize(self.device_map) # self.lm_head = self.lm_head.to(self.cross_decoder.first_device) self.encoder.parallelize(self.device_map) self.decoder.parallelize(self.device_map) self.lm_head = self.lm_head.to(self.decoder.first_device) self.model_parallel = True @add_start_docstrings(DEPARALLELIZE_DOCSTRING) def deparallelize(self): # self.self_decoder.deparallelize() # self.cross_decoder.deparallelize() # self.self_decoder = self.self_decoder.to("cpu") # self.cross_decoder = self.cross_decoder.to("cpu") self.encoder.deparallelize() self.decoder.deparallelize() self.encoder = self.encoder.to("cpu") self.decoder = self.decoder.to("cpu") self.lm_head = self.lm_head.to("cpu") self.model_parallel = False self.device_map = None torch.cuda.empty_cache() def get_input_embeddings(self): return self.shared def set_input_embeddings(self, new_embeddings): self.shared = new_embeddings # self.self_decoder.set_input_embeddings(new_embeddings) # self.cross_decoder.set_input_embeddings(new_embeddings) self.encoder.set_input_embeddings(new_embeddings) self.decoder.set_input_embeddings(new_embeddings) def 
set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def get_output_embeddings(self): return self.lm_head def get_encoder(self): # return self.self_decoder return self.encoder def get_decoder(self): # return self.cross_decoder return self.decoder @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=DualDecoderLMOutput, config_class=_CONFIG_FOR_DOC_DDT5) def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, decoder_head_mask: Optional[torch.FloatTensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, # encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple[torch.FloatTensor], DualDecoderLMOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ..., config.vocab_size - 1]`. 
All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` Returns: Examples: ```python >>> from transformers import T5Tokenizer, T5DualDecoderLMHeadModel >>> tokenizer = T5Tokenizer.from_pretrained("t5-small") >>> model = T5ForConditionalGeneration.from_pretrained("t5-small") >>> # training >>> input_ids = tokenizer("The walks in park", return_tensors="pt").input_ids >>> labels = tokenizer(" cute dog the ", return_tensors="pt").input_ids >>> outputs = model(input_ids=input_ids, labels=labels) >>> loss = outputs.loss >>> logits = outputs.logits >>> # inference >>> input_ids = tokenizer( ... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt" ... ).input_ids # Batch size 1 >>> outputs = model.generate(input_ids) >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True)) >>> # studies have shown that owning a dog is good for you. ```""" use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask if head_mask is not None and decoder_head_mask is None: if self.config.num_layers == self.config.num_decoder_layers: warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) decoder_head_mask = head_mask if past_key_values is not None: self_decoder_past_key_value = past_key_values[0] cross_decoder_past_key_value = past_key_values[1] else: self_decoder_past_key_value, cross_decoder_past_key_value = None, None if labels is not None and input_ids is None and inputs_embeds is None: # get decoder inputs from shifting lm labels to the right input_ids = self._shift_right(labels) # self attention decoder # self_decoder_outputs = self.self_decoder( self_decoder_outputs = self.encoder( input_ids=input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, 
past_key_values=self_decoder_past_key_value, head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = self_decoder_outputs[0] position_bias = self_decoder_outputs[-1] # get encoder hidden states # encoder_hidden_states = None if encoder_outputs is None else encoder_outputs[0] # encoder_attention_mask = None # Set device for model parallelism if self.model_parallel: # torch.cuda.set_device(self.cross_decoder.first_device) # hidden_states = hidden_states.to(self.cross_decoder.first_device) # if attention_mask is not None: # attention_mask = attention_mask.to(self.cross_decoder.first_device) torch.cuda.set_device(self.decoder.first_device) hidden_states = hidden_states.to(self.decoder.first_device) if attention_mask is not None: attention_mask = attention_mask.to(self.decoder.first_device) # cross attention decoder # cross_decoder_outputs = self.cross_decoder( cross_decoder_outputs = self.decoder( attention_mask=attention_mask, inputs_embeds=hidden_states, position_bias=position_bias, past_key_values=cross_decoder_past_key_value, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, head_mask=decoder_head_mask, cross_attn_head_mask=cross_attn_head_mask, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) sequence_output = cross_decoder_outputs[0] # Set device for model parallelism if self.model_parallel: # torch.cuda.set_device(self.self_decoder.first_device) # self.lm_head = self.lm_head.to(self.self_decoder.first_device) torch.cuda.set_device(self.encoder.first_device) self.lm_head = self.lm_head.to(self.encoder.first_device) sequence_output = sequence_output.to(self.lm_head.weight.device) if self.config.tie_word_embeddings: # Rescale output before projecting on vocab # See 
https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 sequence_output = sequence_output * (self.model_dim**-0.5) lm_logits = self.lm_head(sequence_output) loss = None if labels is not None: loss_fct = CrossEntropyLoss(ignore_index=-100) loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 if self_decoder_outputs.past_key_values is None or cross_decoder_outputs.past_key_values is None: past_key_values = None else: past_key_values=(self_decoder_outputs.past_key_values, cross_decoder_outputs.past_key_values) if not return_dict: output = (lm_logits, past_key_values) + cross_decoder_outputs[2:] + (self_decoder_outputs[0],) + self_decoder_outputs[2:] return ((loss,) + output) if loss is not None else output return DualDecoderLMOutput( loss=loss, logits=lm_logits, past_key_values=past_key_values, cross_decoder_hidden_states=cross_decoder_outputs.hidden_states, cross_decoder_attentions=cross_decoder_outputs.attentions, cross_attentions=cross_decoder_outputs.cross_attentions, self_decoder_last_hidden_state=self_decoder_outputs.last_hidden_state, self_decoder_hidden_states=self_decoder_outputs.hidden_states, self_decoder_attentions=self_decoder_outputs.attentions, ) def prepare_inputs_for_generation( self, input_ids, past=None, attention_mask=None, head_mask=None, decoder_head_mask=None, cross_attn_head_mask=None, use_cache=None, # encoder_outputs=None, encoder_hidden_states=None, encoder_attention_mask=None, **kwargs ): # cut decoder_input_ids if past is used if past is not None: input_ids = input_ids[:, -1:] return { "input_ids": input_ids, "past_key_values": past, # "encoder_outputs": encoder_outputs, "encoder_hidden_states": encoder_hidden_states, "encoder_attention_mask": encoder_attention_mask, "attention_mask": attention_mask, "head_mask": 
head_mask, "decoder_head_mask": decoder_head_mask, "cross_attn_head_mask": cross_attn_head_mask, "use_cache": use_cache, } def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return self._shift_right(labels) def _reorder_cache(self, past, beam_idx): if past is None: logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") return past return (self._reorder_cache_single(past[0], beam_idx), self._reorder_cache_single(past[1], beam_idx)) def _reorder_cache_single(self, past, beam_idx): # if decoder past is not included in output # speedy decoding is disabled and no need to reorder if past is None: logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") return past reordered_decoder_past = () for layer_past_states in past: # get the correct batch idx from layer past batch dim # batch dim of `past` is at 2nd position reordered_layer_past_states = () for layer_past_state in layer_past_states: # need to set correct `past` for each of the four key / value states reordered_layer_past_states = reordered_layer_past_states + ( layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)), ) assert reordered_layer_past_states[0].shape == layer_past_states[0].shape assert len(reordered_layer_past_states) == len(layer_past_states) reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) return reordered_decoder_past @add_start_docstrings("""T5 Dual Decoder with a `language modeling` head on top.""", T5_START_DOCSTRING) class T5DualDecoderDoubleHeadsModel(T5PreTrainedModel): def __init__(self, config: T5Config, add_pooling_layer: bool = True): config.is_encoder_decoder = False config.is_decoder = True super().__init__(config) self.model_dim = config.d_model self.shared = nn.Embedding(config.vocab_size, config.d_model) self_decoder_config = copy.deepcopy(config) self_decoder_config.is_decoder = True self_decoder_config.is_encoder_decoder = False 
self_decoder_config.add_cross_attention = False # self.self_decoder = T5DecoderStack(self_decoder_config, self.shared) self.encoder = T5DecoderStack(self_decoder_config, self.shared) cross_decoder_config = copy.deepcopy(config) cross_decoder_config.is_decoder = True cross_decoder_config.is_encoder_decoder = False cross_decoder_config.add_cross_attention = True cross_decoder_config.num_layers = config.num_decoder_layers # self.cross_decoder = T5DecoderStack(cross_decoder_config, has_relative_attention_bias=False) self.decoder = T5DecoderStack(cross_decoder_config, self.shared, has_relative_attention_bias=False) self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) sequence_summary_config = copy.deepcopy(config) sequence_summary_config.summary_type = "cls_index" self.ss_head = SequenceSummary(config) # Initialize weights and apply final processing self.post_init() # Model parallel self.model_parallel = False self.device_map = None @add_start_docstrings(PARALLELIZE_DOCSTRING) def parallelize(self, device_map=None): self.device_map = ( # get_device_map(len(self.self_decoder.block), range(torch.cuda.device_count())) get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) if device_map is None else device_map ) # assert_device_map(self.device_map, len(self.self_decoder.block)) assert_device_map(self.device_map, len(self.encoder.block)) # self.self_decoder.parallelize(self.device_map) # self.cross_decoder.parallelize(self.device_map) # self.lm_head = self.lm_head.to(self.cross_decoder.first_device) # self.ss_head = self.ss_head.to(self.cross_decoder.first_device) self.encoder.parallelize(self.device_map) self.decoder.parallelize(self.device_map) self.lm_head = self.lm_head.to(self.decoder.first_device) self.ss_head = self.ss_head.to(self.decoder.first_device) self.model_parallel = True @add_start_docstrings(DEPARALLELIZE_DOCSTRING) def deparallelize(self): # self.self_decoder.deparallelize() # self.cross_decoder.deparallelize() # 
self.self_decoder = self.self_decoder.to("cpu") # self.cross_decoder = self.cross_decoder.to("cpu") self.encoder.deparallelize() self.decoder.deparallelize() self.encoder = self.encoder.to("cpu") self.decoder = self.decoder.to("cpu") self.lm_head = self.lm_head.to("cpu") self.ss_head = self.ss_head.to("cpu") self.model_parallel = False self.device_map = None torch.cuda.empty_cache() def get_input_embeddings(self): return self.shared def set_input_embeddings(self, new_embeddings): self.shared = new_embeddings # self.self_decoder.set_input_embeddings(new_embeddings) # self.cross_decoder.set_input_embeddings(new_embeddings) self.encoder.set_input_embeddings(new_embeddings) self.decoder.set_input_embeddings(new_embeddings) def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def get_output_embeddings(self): return self.lm_head def get_encoder(self): # return self.self_decoder return self.encoder def get_decoder(self): # return self.cross_decoder return self.decoder @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=DualDecoderDoubleHeadsOutput, config_class=_CONFIG_FOR_DOC_DDT5) def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, decoder_head_mask: Optional[torch.FloatTensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, # encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple[torch.FloatTensor], 
DualDecoderDoubleHeadsOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` Returns: Examples: ```python >>> from transformers import T5Tokenizer, T5DualDecoderDoubleHeadsModel >>> tokenizer = T5Tokenizer.from_pretrained("veld-t5-base") >>> model = T5DualDecoderDoubleHeadsModel.from_pretrained("veld-t5-base") >>> # training >>> input_ids = tokenizer("The walks in park", return_tensors="pt").input_ids >>> labels = tokenizer(" cute dog the ", return_tensors="pt").input_ids >>> outputs = model(input_ids=input_ids, labels=labels) >>> loss = outputs.loss >>> logits = outputs.logits >>> # inference >>> input_ids = tokenizer( ... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt" ... ).input_ids # Batch size 1 >>> outputs = model.generate(input_ids) >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True)) >>> # studies have shown that owning a dog is good for you. 
```""" use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask if head_mask is not None and decoder_head_mask is None: if self.config.num_layers == self.config.num_decoder_layers: warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) decoder_head_mask = head_mask if past_key_values is not None: self_decoder_past_key_value = past_key_values[0] cross_decoder_past_key_value = past_key_values[1] else: self_decoder_past_key_value, cross_decoder_past_key_value = None, None if labels is not None and input_ids is None and inputs_embeds is None: # get decoder inputs from shifting lm labels to the right input_ids = self._shift_right(labels) # self attention decoder # self_decoder_outputs = self.self_decoder( self_decoder_outputs = self.encoder( input_ids=input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, past_key_values=self_decoder_past_key_value, head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = self_decoder_outputs[0] position_bias = self_decoder_outputs[-1] # get encoder hidden states # encoder_hidden_states = None if encoder_outputs is None else encoder_outputs[0] # encoder_attention_mask = None # Set device for model parallelism if self.model_parallel: # torch.cuda.set_device(self.cross_decoder.first_device) # hidden_states = hidden_states.to(self.cross_decoder.first_device) # if attention_mask is not None: # attention_mask = attention_mask.to(self.cross_decoder.first_device) torch.cuda.set_device(self.decoder.first_device) hidden_states = hidden_states.to(self.decoder.first_device) if attention_mask is not None: attention_mask = attention_mask.to(self.decoder.first_device) # cross attention decoder # cross_decoder_outputs = self.cross_decoder( 
cross_decoder_outputs = self.decoder( attention_mask=attention_mask, inputs_embeds=hidden_states, position_bias=position_bias, past_key_values=cross_decoder_past_key_value, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, head_mask=decoder_head_mask, cross_attn_head_mask=cross_attn_head_mask, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) sequence_output = cross_decoder_outputs[0] # Set device for model parallelism if self.model_parallel: # torch.cuda.set_device(self.self_decoder.first_device) # self.lm_head = self.lm_head.to(self.self_decoder.first_device) torch.cuda.set_device(self.encoder.first_device) self.lm_head = self.lm_head.to(self.encoder.first_device) sequence_output = sequence_output.to(self.lm_head.weight.device) if self.config.tie_word_embeddings: # Rescale output before projecting on vocab # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 sequence_output = sequence_output * (self.model_dim**-0.5) lm_logits = self.lm_head(sequence_output) # cls_index = None if attention_mask is None else get_last_token_index(attention_mask) if self.config.pad_token_id is None: cls_index = None else: if input_ids is not None: cls_index = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1 else: cls_index = None logger.warning( f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. 
Results may be " "unexpected if using padding tokens in conjunction with `inputs_embeds.`" ) ss_logits = self.ss_head(hidden_states, cls_index=cls_index) loss = None if labels is not None: loss_fct = CrossEntropyLoss(ignore_index=-100) loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 if self_decoder_outputs.past_key_values is None or cross_decoder_outputs.past_key_values is None: past_key_values = None else: past_key_values=(self_decoder_outputs.past_key_values, cross_decoder_outputs.past_key_values) if not return_dict: output = (lm_logits, ss_logits, past_key_values) + cross_decoder_outputs[2:] + (self_decoder_outputs[0],) + self_decoder_outputs[2:] return ((loss,) + output) if loss is not None else output return DualDecoderDoubleHeadsOutput( loss=loss, logits=lm_logits, ss_logits=ss_logits, past_key_values=past_key_values, cross_decoder_hidden_states=cross_decoder_outputs.hidden_states, cross_decoder_attentions=cross_decoder_outputs.attentions, cross_attentions=cross_decoder_outputs.cross_attentions, self_decoder_last_hidden_state=self_decoder_outputs.last_hidden_state, self_decoder_hidden_states=self_decoder_outputs.hidden_states, self_decoder_attentions=self_decoder_outputs.attentions, ) def prepare_inputs_for_generation( self, input_ids, past=None, attention_mask=None, head_mask=None, decoder_head_mask=None, cross_attn_head_mask=None, use_cache=None, # encoder_outputs=None, encoder_hidden_states=None, encoder_attention_mask=None, **kwargs ): # cut decoder_input_ids if past is used if past is not None: input_ids = input_ids[:, -1:] return { "input_ids": input_ids, "past_key_values": past, # "encoder_outputs": encoder_outputs, "encoder_hidden_states": encoder_hidden_states, "encoder_attention_mask": encoder_attention_mask, "attention_mask": attention_mask, "head_mask": head_mask, "decoder_head_mask": 
decoder_head_mask, "cross_attn_head_mask": cross_attn_head_mask, "use_cache": use_cache, } def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return self._shift_right(labels) def _reorder_cache(self, past, beam_idx): if past is None: logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") return past return (self._reorder_cache_single(past[0], beam_idx), self._reorder_cache_single(past[1], beam_idx)) def _reorder_cache_single(self, past, beam_idx): # if decoder past is not included in output # speedy decoding is disabled and no need to reorder if past is None: logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") return past reordered_decoder_past = () for layer_past_states in past: # get the correct batch idx from layer past batch dim # batch dim of `past` is at 2nd position reordered_layer_past_states = () for layer_past_state in layer_past_states: # need to set correct `past` for each of the four key / value states reordered_layer_past_states = reordered_layer_past_states + ( layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)), ) assert reordered_layer_past_states[0].shape == layer_past_states[0].shape assert len(reordered_layer_past_states) == len(layer_past_states) reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) return reordered_decoder_past from transformers.configuration_utils import PretrainedConfig from transformers.modeling_utils import PreTrainedModel from transformers.models.vision_encoder_decoder.modeling_vision_encoder_decoder import ( VISION_ENCODER_DECODER_START_DOCSTRING, VISION_ENCODER_DECODER_INPUTS_DOCSTRING, ) from transformers.models.auto.configuration_auto import AutoConfig from transformers.models.auto.modeling_auto import AutoModel from transformers import ViTModel, ViTConfig from .configuration_veld import VELDConfig _CONFIG_FOR_DOC_VELDT5 = "VELDConfig" @dataclass class 
VELDDoubleHeadsOutput(ModelOutput): """ Base class for sequence-to-sequence language models outputs. Args: loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): Language modeling loss. logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the self-attention heads. 
cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the weighted average in the cross-attention heads. encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Sequence of hidden-states at the output of the last layer of the encoder of the model. encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the self-attention heads. 
""" loss: Optional[torch.FloatTensor] = None c_loss: Optional[torch.FloatTensor] = None logits: torch.FloatTensor = None e_logits_g: torch.FloatTensor = None e_logits_l: torch.FloatTensor = None d_logits: torch.FloatTensor = None past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None cross_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_last_hidden_state: Optional[torch.FloatTensor] = None encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None @add_start_docstrings(VISION_ENCODER_DECODER_START_DOCSTRING) class VELDModel(PreTrainedModel): r""" [`VELDModel`] is a generic model class that will be instantiated as a transformer architecture with one of the base vision model classes of the library as encoder and another one as dual decoder when created with the :meth*~transformers.AutoModel.from_pretrained* class method for the encoder. 
""" config_class = VELDConfig base_model_prefix = "veld" main_input_name = "pixel_values" supports_gradient_checkpointing = True def __init__( self, config: Optional[PretrainedConfig] = None, encoder: Optional[PreTrainedModel] = None, decoder: Optional[PreTrainedModel] = None, ): if config is None and (encoder is None or decoder is None): raise ValueError("Either a configuration or an encoder and a decoder has to be provided.") if config is None: config = VELDConfig.from_encoder_decoder_configs(encoder.config, decoder.config) else: if not isinstance(config, self.config_class): raise ValueError(f"Config: {config} has to be of type {self.config_class}") if config.decoder.cross_attention_hidden_size is not None: if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size: raise ValueError( "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal" f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for" f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for" " `config.encoder.hidden_size`." 
def __init__(
    self,
    config: Optional[PretrainedConfig] = None,
    encoder: Optional[PreTrainedModel] = None,
    decoder: Optional[PreTrainedModel] = None,
):
    """
    Assemble the model from a composite config and/or ready-made sub-models.

    Args:
        config: Composite configuration; must be an instance of `config_class`. When omitted it is derived
            from the configs of `encoder` and `decoder`, which then both have to be supplied.
        encoder: Vision encoder instance. When omitted, a `ViTModel` without pooling layer is created from
            `config.encoder`.
        decoder: Text decoder instance. When omitted, a `T5DualDecoderDoubleHeadsModel` is created from
            `config.decoder`.

    Raises:
        ValueError: If neither a config nor both sub-models are given, if `config` has the wrong type, if
            the decoder's `cross_attention_hidden_size` is set but differs from the encoder's `hidden_size`,
            or if the encoder exposes output embeddings (an LM head).
    """
    if config is None:
        if encoder is None or decoder is None:
            raise ValueError("Either a configuration or an encoder and a decoder has to be provided.")
        config = VELDConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
    elif not isinstance(config, self.config_class):
        raise ValueError(f"Config: {config} has to be of type {self.config_class}")

    cross_size = config.decoder.cross_attention_hidden_size
    if cross_size is not None and cross_size != config.encoder.hidden_size:
        raise ValueError(
            "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal"
            f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
            f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for"
            " `config.encoder.hidden_size`."
        )

    # Input and output embeddings must stay untied; set this before the base class consumes the config.
    config.tie_word_embeddings = False
    super().__init__(config)

    # Fall back to freshly initialised sub-models for whichever one was not handed in.
    self.encoder = ViTModel(config.encoder, add_pooling_layer=False) if encoder is None else encoder
    self.decoder = T5DualDecoderDoubleHeadsModel(config.decoder) if decoder is None else decoder

    shared_encoder_config = self.config.encoder
    shared_decoder_config = self.config.decoder
    if self.encoder.config.to_dict() != shared_encoder_config.to_dict():
        logger.warning(
            f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config:"
            f" {self.config.encoder}"
        )
    if self.decoder.config.to_dict() != shared_decoder_config.to_dict():
        logger.warning(
            f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:"
            f" {self.config.decoder}"
        )

    # Point both sub-models at the shared config objects so later config updates stay in sync.
    self.encoder.config = shared_encoder_config
    self.decoder.config = shared_decoder_config

    # A projection is only created when the widths differ and the decoder declares no cross-attention width.
    widths_differ = self.encoder.config.hidden_size != self.decoder.config.hidden_size
    if widths_differ and self.decoder.config.cross_attention_hidden_size is None:
        self.enc_to_dec_proj = nn.Linear(self.encoder.config.hidden_size, self.decoder.config.hidden_size)

    if self.encoder.get_output_embeddings() is not None:
        raise ValueError(
            f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head"
        )

    # Both pooling heads share a copy of the encoder config switched to attention-based summarisation.
    pooling_config = copy.deepcopy(self.encoder.config)
    pooling_config.summary_type = "attn"
    self.global_pooling = SequenceSummary(pooling_config, num_queries=self.config.num_queries_global)
    self.local_pooling = SequenceSummary(pooling_config, num_queries=self.config.num_queries_local)
def _set_gradient_checkpointing(self, module, value=False):
    """Forward the gradient-checkpointing switch to the encoder first, then to the decoder."""
    for submodel in (self.encoder, self.decoder):
        submodel._set_gradient_checkpointing(module, value=value)


def get_encoder(self):
    """Return the wrapped encoder."""
    wrapped_encoder = self.encoder
    return wrapped_encoder


def get_decoder(self):
    """Return the wrapped decoder."""
    wrapped_decoder = self.decoder
    return wrapped_decoder


def get_output_embeddings(self):
    """Return whatever the decoder reports as its output embeddings."""
    decoder = self.decoder
    return decoder.get_output_embeddings()


def set_output_embeddings(self, new_embeddings):
    """Hand `new_embeddings` to the decoder's setter and return its result."""
    decoder = self.decoder
    return decoder.set_output_embeddings(new_embeddings)


@classmethod
def from_pretrained(cls, *args, **kwargs):
    """
    Load the model through the parent `from_pretrained` with `_fast_init` forced to `False`.

    A warning is logged when the caller explicitly asked for fast initialization.
    """
    fast_init_requested = kwargs.get("_fast_init", False)
    if fast_init_requested:
        # At the moment fast initialization is not supported for composite models.
        logger.warning(
            "Fast initialization is currently not supported for VELDModel. "
            "Falling back to slow initialization..."
        )
    kwargs["_fast_init"] = False
    return super().from_pretrained(*args, **kwargs)
@classmethod
def from_encoder_decoder_pretrained(
    cls,
    encoder_pretrained_model_name_or_path: str = None,
    decoder_pretrained_model_name_or_path: str = None,
    *model_args,
    **kwargs
) -> PreTrainedModel:
    r"""
    Build a [`VELDModel`] from a pretrained ViT encoder and a pretrained T5 dual decoder.

    Params:
        encoder_pretrained_model_name_or_path (`str`, *optional*):
            Identifier or path handed to `ViTConfig.from_pretrained` / `ViTModel.from_pretrained`, e.g.
            `google/vit-base-patch16-224-in21k` or `./my_model_directory/`. Required unless an encoder
            instance is passed as `encoder_model`.
        decoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`):
            Identifier or path handed to `T5Config.from_pretrained` /
            `T5DualDecoderDoubleHeadsModel.from_pretrained`. Required unless a decoder instance is passed as
            `decoder_model`.
        model_args (remaining positional arguments, *optional*):
            Forwarded to the encoder's `from_pretrained` call only.
        kwargs (remaining dictionary of keyword arguments, *optional*):
            Routed by prefix:

            - `encoder_*` goes to the encoder (prefix stripped); `encoder_model` supplies a ready instance.
            - `decoder_*` goes to the decoder (prefix stripped); `decoder_model` supplies a ready instance.
            - Everything else is passed to `VELDConfig.from_encoder_decoder_configs`.

    When no encoder config is passed, the loaded one gets `is_decoder` and `add_cross_attention` forced to
    `False`; when no decoder config is passed, the loaded one gets both forced to `True`. A decoder config
    supplied by the caller is only checked, and a warning is logged if either flag is `False`. The encoder is
    loaded with `add_pooling_layer=False`, and the resulting config always has `tie_word_embeddings=False`.

    Example:

    ```python
    >>> from modeling_veld import VELDModel

    >>> # initialize a vit-t5 from a pretrained ViT and a pretrained T5 model. Note that the cross-attention layers will be randomly initialized
    >>> model = VELDModel.from_encoder_decoder_pretrained(
    ...     "google/vit-base-patch16-224-in21k", "t5-base"
    ... )
    >>> # saving model after fine-tuning
    >>> model.save_pretrained("./vit-t5")
    >>> # load fine-tuned model
    >>> model = VELDModel.from_pretrained("./vit-t5")
    ```"""
    enc_prefix = "encoder_"
    dec_prefix = "decoder_"

    # Move prefixed entries out of `kwargs`, stripping the prefix; what remains configures the parent model.
    kwargs_encoder = {}
    kwargs_decoder = {}
    for name in list(kwargs):
        if name.startswith(enc_prefix):
            kwargs_encoder[name[len(enc_prefix) :]] = kwargs.pop(name)
        elif name.startswith(dec_prefix):
            kwargs_decoder[name[len(dec_prefix) :]] = kwargs.pop(name)

    # Load and initialize the encoder and decoder.
    # The distinction between encoder and decoder at the model level is made
    # by the value of the flag `is_decoder` that we need to set correctly.
    encoder = kwargs_encoder.pop("model", None)
    if encoder is None:
        if encoder_pretrained_model_name_or_path is None:
            raise ValueError(
                "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
                "to be defined."
            )

        if "config" not in kwargs_encoder:
            vit_config, kwargs_encoder = ViTConfig.from_pretrained(
                encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True
            )
            if vit_config.is_decoder is True or vit_config.add_cross_attention is True:
                logger.info(
                    f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
                    "from a decoder model. Cross-attention and casual mask are disabled."
                )
                vit_config.is_decoder = False
                vit_config.add_cross_attention = False
            kwargs_encoder["config"] = vit_config

        encoder = ViTModel.from_pretrained(
            encoder_pretrained_model_name_or_path,
            *model_args,
            add_pooling_layer=False,
            **kwargs_encoder,
        )

    decoder = kwargs_decoder.pop("model", None)
    if decoder is None:
        if decoder_pretrained_model_name_or_path is None:
            raise ValueError(
                "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
                "to be defined."
            )

        if "config" not in kwargs_decoder:
            t5_config, kwargs_decoder = T5Config.from_pretrained(
                decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
            )
            if t5_config.is_decoder is False or t5_config.add_cross_attention is False:
                logger.info(
                    f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention"
                    f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if"
                    f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
                )
                t5_config.is_decoder = True
                t5_config.add_cross_attention = True
            kwargs_decoder["config"] = t5_config

        # A caller-supplied decoder config is not modified, only checked.
        effective_config = kwargs_decoder["config"]
        if effective_config.is_decoder is False or effective_config.add_cross_attention is False:
            logger.warning(
                f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
                f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
                "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
                "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
                "`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
            )

        decoder = T5DualDecoderDoubleHeadsModel.from_pretrained(
            decoder_pretrained_model_name_or_path, **kwargs_decoder
        )

    # Instantiate the composite config with the remaining (un-prefixed) kwargs.
    config = VELDConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)

    # Make sure input & output embeddings are not tied.
    config.tie_word_embeddings = False
    return cls(encoder=encoder, decoder=decoder, config=config)
# NOTE(review): the generated "Returns" section is built from `Seq2SeqLMOutput`, while this method
# actually returns `VELDDoubleHeadsOutput` — confirm and switch `output_type` if the docs should match.
@add_start_docstrings_to_model_forward(VISION_ENCODER_DECODER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC_VELDT5)
def forward(
    self,
    pixel_values=None,
    decoder_input_ids=None,
    decoder_attention_mask=None,
    encoder_outputs=None,
    past_key_values=None,
    decoder_inputs_embeds=None,
    labels=None,
    return_contrastive_loss=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    logit_temperature=1.0,
    label_smoothing=0.0,
    **kwargs,
):
    r"""
    Encode the image (unless `encoder_outputs` is given), pool the encoder states with the local pooling
    head, decode against the pooled states, and optionally compute a language-modeling loss and a
    contrastive loss.

    Additional arguments:
        labels: When given, a cross-entropy language-modeling loss is computed over the decoder logits, and
            `decoder_input_ids` is derived from `labels` if neither it nor `decoder_inputs_embeds` is passed.
        return_contrastive_loss: Any value other than `None` (including `False`) enables the global pooling
            output and the contrastive loss, provided encoder outputs are available.
        logit_temperature: Divisor applied to the similarity scores before the contrastive cross-entropy.
        label_smoothing: Label smoothing used by the contrastive cross-entropy only.
        kwargs: Keys starting with `decoder_` are passed to the decoder with the prefix stripped; all other
            keys are passed to the encoder.

    With `return_dict=False` a tuple is returned whose leading entries are `loss` (if computed) and `c_loss`
    (if computed), followed by the fields in the order of `VELDDoubleHeadsOutput` starting at `logits`.

    Returns:

    Examples:

    ```python
    >>> from transformers import AutoTokenizer, ViTFeatureExtractor
    >>> from modeling_veld import VELDModel
    >>> import requests
    >>> from PIL import Image
    >>> import torch

    >>> processor = ViTFeatureExtractor.from_pretrained("KETI-AIR/veld-base")
    >>> tokenizer = AutoTokenizer.from_pretrained("KETI-AIR/veld-base")
    >>> model = VELDModel.from_pretrained("KETI-AIR/veld-base")

    >>> # load image from the IAM dataset
    >>> url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

    >>> # training
    >>> pixel_values = processor(image, return_tensors="pt").pixel_values
    >>> text = "hello world"
    >>> labels = tokenizer(text, return_tensors="pt").input_ids
    >>> outputs = model(pixel_values=pixel_values, labels=labels)
    >>> loss = outputs.loss

    >>> # inference (generation)
    >>> generated_ids = model.generate(pixel_values, max_new_tokens=20)
    >>> generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    ```"""
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
    kwargs_decoder = {
        argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
    }

    if encoder_outputs is None and pixel_values is not None:
        # Always request a ModelOutput internally: the code below reads named fields
        # (`last_hidden_state`, `hidden_states`, `attentions`), which a plain tuple does not have.
        encoder_outputs = self.encoder(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            **kwargs_encoder,
        )
    elif isinstance(encoder_outputs, tuple):
        encoder_outputs = BaseModelOutput(*encoder_outputs)

    encoder_hidden_states = None if encoder_outputs is None else encoder_outputs[0]
    pooler_output_local = None if encoder_outputs is None else self.local_pooling(encoder_hidden_states)
    # The global summary is pooled from the local summary, before any projection to the decoder width.
    pooler_output_global = (
        None
        if encoder_outputs is None or return_contrastive_loss is None
        else self.global_pooling(pooler_output_local).squeeze(1)
    )

    # Same condition under which `__init__` creates `enc_to_dec_proj`.
    needs_projection = (
        self.encoder.config.hidden_size != self.decoder.config.hidden_size
        and self.decoder.config.cross_attention_hidden_size is None
    )

    # optionally project encoder_hidden_states
    if needs_projection and pooler_output_local is not None:
        pooler_output_local = self.enc_to_dec_proj(pooler_output_local)

    # No attention mask is built for the pooled encoder states.
    encoder_attention_mask = None

    if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
        decoder_input_ids = self.decoder.prepare_decoder_input_ids_from_labels(labels)

    # Decode. As for the encoder, a ModelOutput is requested regardless of `return_dict`, because named
    # fields such as `ss_logits` and `self_decoder_hidden_states` are read below; the tuple form for the
    # caller is assembled at the end.
    decoder_outputs = self.decoder(
        input_ids=decoder_input_ids,
        attention_mask=decoder_attention_mask,
        encoder_hidden_states=pooler_output_local,
        encoder_attention_mask=encoder_attention_mask,
        inputs_embeds=decoder_inputs_embeds,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        use_cache=use_cache,
        past_key_values=past_key_values,
        return_dict=True,
        **kwargs_decoder,
    )

    # Compute loss independent from decoder (as some shift the logits inside them)
    loss = None
    if labels is not None:
        lm_loss_fct = CrossEntropyLoss()
        # `reshape` instead of `view` so non-contiguous label tensors are accepted too.
        loss = lm_loss_fct(
            decoder_outputs.logits.reshape(-1, self.decoder.config.vocab_size), labels.reshape(-1)
        )

    c_loss = None
    if return_contrastive_loss is not None and encoder_outputs is not None:
        decoder_logits = decoder_outputs.ss_logits
        encoder_logits = pooler_output_global
        if needs_projection:
            encoder_logits = self.enc_to_dec_proj(encoder_logits)

        encoder_logits = nn.functional.normalize(encoder_logits)
        decoder_logits = nn.functional.normalize(decoder_logits)

        # Row i of `scores` compares decoder summary i with every encoder summary of the batch;
        # the matching pair sits on the diagonal, hence the `arange` targets.
        batch_size = encoder_logits.size(0)
        scores = torch.mm(decoder_logits, encoder_logits.t())
        target = torch.arange(batch_size, device=decoder_logits.device)

        contrastive_loss_fct = CrossEntropyLoss(label_smoothing=label_smoothing)
        c_loss = contrastive_loss_fct(scores / logit_temperature, target) + contrastive_loss_fct(
            scores.t() / logit_temperature, target
        )

    # Self- and cross-decoder tuples are concatenated, but only when both are present.
    if (
        decoder_outputs.self_decoder_hidden_states is not None
        and decoder_outputs.cross_decoder_hidden_states is not None
    ):
        decoder_hidden_states = (
            decoder_outputs.self_decoder_hidden_states + decoder_outputs.cross_decoder_hidden_states
        )
    else:
        decoder_hidden_states = None

    if decoder_outputs.self_decoder_attentions is not None and decoder_outputs.cross_decoder_attentions is not None:
        decoder_attentions = decoder_outputs.self_decoder_attentions + decoder_outputs.cross_decoder_attentions
    else:
        decoder_attentions = None

    encoder_last_hidden_state = None if encoder_outputs is None else encoder_outputs.last_hidden_state
    encoder_all_hidden_states = None if encoder_outputs is None else encoder_outputs.hidden_states
    encoder_attentions = None if encoder_outputs is None else encoder_outputs.attentions

    if not return_dict:
        outputs = (
            decoder_outputs.logits,
            pooler_output_global,
            pooler_output_local,
            decoder_outputs.ss_logits,
            decoder_outputs.past_key_values,
            decoder_hidden_states,
            decoder_attentions,
            decoder_outputs.cross_attentions,
            encoder_last_hidden_state,
            encoder_all_hidden_states,
            encoder_attentions,
        )
        if c_loss is not None:
            outputs = (c_loss,) + outputs
        if loss is not None:
            return (loss,) + outputs
        return outputs

    return VELDDoubleHeadsOutput(
        loss=loss,
        c_loss=c_loss,
        logits=decoder_outputs.logits,
        e_logits_g=pooler_output_global,
        e_logits_l=pooler_output_local,
        d_logits=decoder_outputs.ss_logits,
        past_key_values=decoder_outputs.past_key_values,
        decoder_hidden_states=decoder_hidden_states,
        decoder_attentions=decoder_attentions,
        cross_attentions=decoder_outputs.cross_attentions,
        encoder_last_hidden_state=encoder_last_hidden_state,
        encoder_hidden_states=encoder_all_hidden_states,
        encoder_attentions=encoder_attentions,
    )
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
    """Delegate to the decoder's `prepare_decoder_input_ids_from_labels` and return its result."""
    decoder = self.decoder
    return decoder.prepare_decoder_input_ids_from_labels(labels)


def prepare_inputs_for_generation(
    self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
):
    """
    Build the keyword arguments for one generation step.

    The decoder's own `prepare_inputs_for_generation` is called with `input_ids` and `past`; its
    `input_ids`, `past_key_values` and (if present) `attention_mask` entries are re-exposed under the
    decoder-side argument names. `attention_mask`, `use_cache` and `encoder_outputs` are passed through
    unchanged; extra `kwargs` are ignored.
    """
    prepared = self.decoder.prepare_inputs_for_generation(input_ids, past=past)

    if "attention_mask" in prepared:
        decoder_mask = prepared["attention_mask"]
    else:
        decoder_mask = None

    return {
        "attention_mask": attention_mask,
        "decoder_attention_mask": decoder_mask,
        "decoder_input_ids": prepared["input_ids"],
        "encoder_outputs": encoder_outputs,
        "past_key_values": prepared["past_key_values"],
        "use_cache": use_cache,
    }
def resize_token_embeddings(self, *args, **kwargs):
    """
    Refuse to resize embeddings on the composite model.

    Raises:
        NotImplementedError: Always. The message points callers at the decoder's own
            `resize_token_embeddings` method instead.
    """
    # The message previously named `VisionEncoderDecoderModel` and lacked a space after "supported.".
    raise NotImplementedError(
        "Resizing the embedding layers via the VELDModel directly is not supported. Please use the"
        " respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))"
    )