diff --git "a/modeling_veld.py" "b/modeling_veld.py" new file mode 100644--- /dev/null +++ "b/modeling_veld.py" @@ -0,0 +1,2095 @@ +# Copyright 2022 san kim +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import math +import os +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple, Union, Callable + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss +from torch.utils.checkpoint import checkpoint + +try: + from torch.nn import Identity +except ImportError: + # Older PyTorch compatibility + class Identity(nn.Module): + r"""A placeholder identity operator that is argument-insensitive.""" + + def __init__(self, *args, **kwargs): + super().__init__() + + def forward(self, input): + return input + +from transformers.models.t5.modeling_t5 import ( + T5LayerSelfAttention, + T5LayerCrossAttention, + T5LayerFF, + T5PreTrainedModel, + T5LayerNorm, + PARALLELIZE_DOCSTRING, + DEPARALLELIZE_DOCSTRING, + __HEAD_MASK_WARNING_MSG, + T5_START_DOCSTRING, + T5_INPUTS_DOCSTRING +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + BaseModelOutput +) +from transformers.utils import ( + DUMMY_INPUTS, + DUMMY_MASK, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_torch_fx_proxy, + logging, + replace_return_docstrings, + ModelOutput, +) +from transformers.utils.model_parallel_utils import assert_device_map, get_device_map +from transformers import T5Config +from transformers.configuration_utils import PretrainedConfig +from transformers.activations import get_activation + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC_DDT5 = "T5Config" + +def get_last_token_index(mask): + # attention masks: [batch_size, seq] + + batch_size, seq_length = mask.shape[:2] + incr = torch.arange(seq_length, device=mask.device, requires_grad=False) + incr_m = torch.einsum("i,ji->ji", incr, mask) + return torch.argmax(incr_m, dim=1) + +# modified from huggingface transformers lib (add attn pooling) +class SequenceSummary(nn.Module): + r""" + Compute a single vector summary of a sequence hidden states. + + Args: + config ([`PretrainedConfig`]): + The config used by the model. Relevant arguments in the config class of the model are (refer to the actual + config class of your model for the default values it uses): + + - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are: + + - `"last"` -- Take the last token hidden state (like XLNet) + - `"first"` -- Take the first token hidden state (like Bert) + - `"mean"` -- Take the mean of all tokens hidden states + - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2) + - `"attn"` -- Not implemented now, use multi-head attention + + - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction. + - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes + (otherwise to `config.hidden_size`). 
+ - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output, + another string or `None` will add no activation. + - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation. + - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation. + """ + + def __init__(self, config: PretrainedConfig, num_queries=1): + super().__init__() + + self.summary_type = getattr(config, "summary_type", "last") + if self.summary_type == "attn": + # We should use a standard multi-head attention module with absolute positional embedding for that. + # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276 + # We can probably just use the multi-head attention module of PyTorch >=1.1.0 + self.queries = nn.Parameter(torch.empty(num_queries, config.hidden_size)) + nn.init.kaiming_uniform_(self.queries, a=math.sqrt(5)) + self.MultiheadAttention = nn.MultiheadAttention( + config.hidden_size, + config.num_attention_heads, + batch_first=True + ) + layer_norm_eps = getattr(config, "layer_norm_eps", 1e-6) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=layer_norm_eps) + + self.summary = Identity() + if hasattr(config, "summary_use_proj") and config.summary_use_proj: + if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0: + num_classes = config.num_labels + else: + num_classes = config.hidden_size + self.summary = nn.Linear(config.hidden_size, num_classes) + + activation_string = getattr(config, "summary_activation", None) + self.activation: Callable = get_activation(activation_string) if activation_string else Identity() + + self.first_dropout = Identity() + if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0: + self.first_dropout = nn.Dropout(config.summary_first_dropout) + + self.last_dropout = Identity() + if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0: + self.last_dropout = nn.Dropout(config.summary_last_dropout) + + def forward( + self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None + ) -> torch.FloatTensor: + """ + Compute a single vector summary of a sequence hidden states. + + Args: + hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`): + The hidden states of the last layer. + cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*): + Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token. + + Returns: + `torch.FloatTensor`: The summary of the sequence hidden states. 
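+
+        Example (a minimal sketch; the small config values are illustrative assumptions, and
+        `summary_type="attn"` is set explicitly to exercise the attention-pooling branch added here):
+
+        ```python
+        import torch
+        from transformers import T5Config
+
+        config = T5Config(d_model=512, num_heads=8)  # `hidden_size` / `num_attention_heads` resolve to these
+        config.summary_type = "attn"
+
+        summary = SequenceSummary(config, num_queries=1)
+        hidden_states = torch.randn(2, 16, config.hidden_size)  # (batch_size, seq_len, hidden_size)
+        pooled = summary(hidden_states)                          # (batch_size, num_queries, hidden_size)
+        ```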
+ """ + if self.summary_type == "last": + output = hidden_states[:, -1] + elif self.summary_type == "first": + output = hidden_states[:, 0] + elif self.summary_type == "mean": + output = hidden_states.mean(dim=1) + elif self.summary_type == "cls_index": + if cls_index is None: + cls_index = torch.full_like( + hidden_states[..., :1, :], + hidden_states.shape[-2] - 1, + dtype=torch.long, + ) + else: + cls_index = cls_index.unsqueeze(-1).unsqueeze(-1) + cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),)) + # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states + output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size) + elif self.summary_type == "attn": + batch_size = hidden_states.size(0) + queries = self.queries.repeat(batch_size, 1, 1) + output = self.MultiheadAttention(queries, hidden_states, hidden_states, need_weights=False)[0] + + output = self.LayerNorm(output) + + output = self.first_dropout(output) + output = self.summary(output) + output = self.activation(output) + output = self.last_dropout(output) + + return output + +# add_cross_attention +class T5DecoderBlock(nn.Module): + def __init__(self, config, has_relative_attention_bias=False): + super().__init__() + self.is_decoder = config.is_decoder + self.has_cross_attention = config.add_cross_attention + self.layer = nn.ModuleList() + self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)) + if self.has_cross_attention: + self.layer.append(T5LayerCrossAttention(config)) + + self.layer.append(T5LayerFF(config)) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + layer_head_mask=None, + cross_attn_layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + return_dict=True, + ): + + if past_key_value is not None: + if not self.is_decoder: + logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.") + expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 + + if len(past_key_value) != expected_num_past_key_values: + raise ValueError( + f"There should be {expected_num_past_key_values} past states. " + f"{'2 (past / key) for cross attention. 
' if expected_num_past_key_values == 4 else ''}" + f"Got {len(past_key_value)} past key / value states" + ) + + self_attn_past_key_value = past_key_value[:2] + cross_attn_past_key_value = past_key_value[2:] + else: + self_attn_past_key_value, cross_attn_past_key_value = None, None + + self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=self_attn_past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states, present_key_value_state = self_attention_outputs[:2] + attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + do_cross_attention = self.has_cross_attention and encoder_hidden_states is not None + if do_cross_attention: + # the actual query length is unknown for cross attention + # if using past key value states. Need to inject it here + if present_key_value_state is not None: + query_length = present_key_value_state[0].shape[2] + else: + query_length = None + + cross_attention_outputs = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + query_length=query_length, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = cross_attention_outputs[0] + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + # Combine self attn and cross attn key value states + if present_key_value_state is not None: + present_key_value_state = present_key_value_state + cross_attention_outputs[1] + + # Keep cross-attention outputs and relative position weights + attention_outputs = attention_outputs + cross_attention_outputs[2:] + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states) + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if use_cache: + outputs = outputs + (present_key_value_state,) + attention_outputs + else: + outputs = outputs + attention_outputs + + return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + + +@dataclass +class BaseModelOutputWithPastAndCrossAttentionsAndPositionBias(ModelOutput): + """ + Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding) plus position bias. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. 
+ + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + position_bias (`torch.FloatTensor`, *optional*, returned when the model is self-attention decoder): + position_bias is created in the first layer of the self-attention decoder, and it passes through all the layers including layers of the cross-attention decoder. + `torch.FloatTensor` of shape `(batch_size, num_heads, sequence_length, sequence_length)`. 
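+
+        Example (a minimal sketch of where the extra `position_bias` field comes from; the tiny
+        config and the standalone `T5DecoderStack` below are illustrative, not a training setup):
+
+        ```python
+        import torch
+        from transformers import T5Config
+
+        cfg = T5Config(d_model=256, num_heads=4, num_layers=2, use_cache=False)
+        cfg.is_decoder, cfg.is_encoder_decoder, cfg.add_cross_attention = True, False, False
+
+        stack = T5DecoderStack(cfg, embed_tokens=torch.nn.Embedding(cfg.vocab_size, cfg.d_model))
+        out = stack(input_ids=torch.ones(1, 5, dtype=torch.long), return_dict=True)
+        out.position_bias.shape  # (1, num_heads, 5, 5); reusable by the cross-attention stack
+        ```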
+ """ + + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + position_bias: Optional[torch.FloatTensor] = None + + +class T5DecoderStack(T5PreTrainedModel): + def __init__(self, config, embed_tokens=None, has_relative_attention_bias=True): + super().__init__(config) + + self.embed_tokens = embed_tokens + self.is_decoder = config.is_decoder + self.has_cross_attention = config.add_cross_attention + + self.block = nn.ModuleList( + [T5DecoderBlock(config, has_relative_attention_bias=bool(i == 0) and has_relative_attention_bias) for i in range(config.num_layers)] + ) + self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + # Initialize weights and apply final processing + self.post_init() + # Model parallel + self.model_parallel = False + self.device_map = None + self.gradient_checkpointing = False + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + # Check validity of device_map + self.device_map = ( + get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map + ) + assert_device_map(self.device_map, len(self.block)) + self.model_parallel = True + self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) + self.last_device = "cuda:" + str(max(self.device_map.keys())) + # Load onto devices + for k, v in self.device_map.items(): + for layer in v: + cuda_device = "cuda:" + str(k) + self.block[layer] = self.block[layer].to(cuda_device) + + # Set embed_tokens to first layer + self.embed_tokens = self.embed_tokens.to(self.first_device) if self.embed_tokens is not None else self.embed_tokens + # Set final layer norm to last device + self.final_layer_norm = self.final_layer_norm.to(self.last_device) + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def deparallelize(self): + self.model_parallel = False + self.device_map = None + self.first_device = "cpu" + self.last_device = "cpu" + for i in range(len(self.block)): + self.block[i] = self.block[i].to("cpu") + self.embed_tokens = self.embed_tokens.to("cpu") if self.embed_tokens is not None else self.embed_tokens + self.final_layer_norm = self.final_layer_norm.to("cpu") + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, new_embeddings): + self.embed_tokens = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds=None, + position_bias=None, + encoder_decoder_position_bias=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(self.first_device) + self.embed_tokens = self.embed_tokens.to(self.first_device) if self.embed_tokens is not None else self.embed_tokens + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else 
self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError( + f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + + if inputs_embeds is None: + assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings" + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape + + # required mask seq length can be calculated via length of past + mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length + + if use_cache is True: + assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder" + + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + if self.has_cross_attention and encoder_attention_mask is None and encoder_hidden_states is not None: + encoder_seq_length = encoder_hidden_states.shape[1] + encoder_attention_mask = torch.ones( + batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long + ) + + # initialize past_key_values with `None` if past does not exist + if past_key_values is None: + past_key_values = [None] * len(self.block) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.has_cross_attention and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) + present_key_value_states = () if use_cache else None + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.has_cross_attention) else None + + + hidden_states = self.dropout(inputs_embeds) + + for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + layer_head_mask = head_mask[i] + cross_attn_layer_head_mask = cross_attn_head_mask[i] + + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if position_bias is not None: + position_bias = position_bias.to(hidden_states.device) + if encoder_hidden_states is not None: + encoder_hidden_states = encoder_hidden_states.to(hidden_states.device) + if encoder_extended_attention_mask is not None: + encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device) + if encoder_decoder_position_bias is not None: + encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device) + if layer_head_mask is not None: + layer_head_mask = layer_head_mask.to(hidden_states.device) + if cross_attn_layer_head_mask is not None: + cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return tuple(module(*inputs, use_cache, output_attentions)) + + return custom_forward + + layer_outputs = checkpoint( + create_custom_forward(layer_module), + hidden_states, + extended_attention_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, + layer_head_mask, + cross_attn_layer_head_mask, + None, # past_key_value is always None with gradient checkpointing + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask=extended_attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + if use_cache is False: + layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] + + hidden_states, present_key_value_state = layer_outputs[:2] + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + position_bias = layer_outputs[2] + if self.has_cross_attention and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] + # append next layer key value states + if use_cache: + present_key_value_states = present_key_value_states + (present_key_value_state,) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[3],) + if self.has_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[5],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + outputs = tuple( + v + for v in [ + hidden_states, + present_key_value_states, + all_hidden_states, + all_attentions, + all_cross_attentions, + ] + if v is not None + ) + outputs = outputs + (position_bias,) + return outputs + return BaseModelOutputWithPastAndCrossAttentionsAndPositionBias( + last_hidden_state=hidden_states, + past_key_values=present_key_value_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + position_bias=position_bias + ) + + +@dataclass +class DualDecoderModelOutput(ModelOutput): + """ + Base class for model dual decoder's outputs that also contains : pre-computed hidden states that can speed up sequential + decoding. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the decoder of the model. 
+ + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + cross_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the cross-attention decoder at the output of each layer. + cross_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the cross-attention decoder, after the attention softmax, used to compute the weighted average in the + cross-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + self_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the self-attention decoder of the model. + self_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the self-attention decoder at the output of each layer plus the optional initial embedding outputs. + self_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the self-attention decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. 
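+
+        In the dual-decoder models below, `past_key_values` groups the caches of the two stacks
+        as a 2-tuple, which is also the layout their `forward` expects back. A sketch (`model`,
+        `outputs` and `next_token` are assumed names):
+
+        ```python
+        self_decoder_past, cross_decoder_past = outputs.past_key_values
+        next_step = model(input_ids=next_token, past_key_values=outputs.past_key_values)
+        ```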
+ """ + + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + cross_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + cross_decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + self_decoder_last_hidden_state: Optional[torch.FloatTensor] = None + self_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + self_decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + +@dataclass +class DualDecoderLMOutput(ModelOutput): + """ + Base class for sequence-to-sequence language models outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + cross_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the cross-attention decoder at the output of each layer. + cross_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the cross-attention decoder, after the attention softmax, used to compute the weighted average in the + cross-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + self_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the self-attention decoder of the model. + self_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. 
+ + Hidden-states of the self-attention decoder at the output of each layer plus the optional initial embedding outputs. + self_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the self-attention decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None + cross_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + cross_decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + self_decoder_last_hidden_state: Optional[torch.FloatTensor] = None + self_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + self_decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + +@dataclass +class DualDecoderDoubleHeadsOutput(ModelOutput): + """ + Base class for sequence-to-sequence language models outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + ss_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Global representaion of the self-attention decoder. The last token of sequence is used to calculate this representation. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + cross_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the cross-attention decoder at the output of each layer. + cross_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the cross-attention decoder, after the attention softmax, used to compute the weighted average in the + cross-attention heads. 
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + self_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the self-attention decoder of the model. + self_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the self-attention decoder at the output of each layer plus the optional initial embedding outputs. + self_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the self-attention decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + ss_logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None + cross_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + cross_decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + self_decoder_last_hidden_state: Optional[torch.FloatTensor] = None + self_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + self_decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@add_start_docstrings("""T5 Dual Decoder with a `language modeling` head on top.""", T5_START_DOCSTRING) +class T5DualDecoderLMHeadModel(T5PreTrainedModel): + + def __init__(self, config: T5Config, add_pooling_layer: bool = True): + config.is_encoder_decoder = False + config.is_decoder = True + super().__init__(config) + self.model_dim = config.d_model + + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + self_decoder_config = copy.deepcopy(config) + self_decoder_config.is_decoder = True + self_decoder_config.is_encoder_decoder = False + self_decoder_config.add_cross_attention = False + # self.self_decoder = T5DecoderStack(self_decoder_config, self.shared) + self.encoder = T5DecoderStack(self_decoder_config, self.shared) + + cross_decoder_config = copy.deepcopy(config) + cross_decoder_config.is_decoder = True + cross_decoder_config.is_encoder_decoder = False + cross_decoder_config.add_cross_attention = True + cross_decoder_config.num_layers = config.num_decoder_layers + # self.cross_decoder = T5DecoderStack(cross_decoder_config, has_relative_attention_bias=False) + self.decoder = T5DecoderStack(cross_decoder_config, self.shared, has_relative_attention_bias=False) + + self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + # 
Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + self.device_map = ( + # get_device_map(len(self.self_decoder.block), range(torch.cuda.device_count())) + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + # assert_device_map(self.device_map, len(self.self_decoder.block)) + assert_device_map(self.device_map, len(self.encoder.block)) + # self.self_decoder.parallelize(self.device_map) + # self.cross_decoder.parallelize(self.device_map) + # self.lm_head = self.lm_head.to(self.cross_decoder.first_device) + self.encoder.parallelize(self.device_map) + self.decoder.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.decoder.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + # self.self_decoder.deparallelize() + # self.cross_decoder.deparallelize() + # self.self_decoder = self.self_decoder.to("cpu") + # self.cross_decoder = self.cross_decoder.to("cpu") + self.encoder.deparallelize() + self.decoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.decoder = self.decoder.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + # self.self_decoder.set_input_embeddings(new_embeddings) + # self.cross_decoder.set_input_embeddings(new_embeddings) + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_output_embeddings(self): + return self.lm_head + + def get_encoder(self): + # return self.self_decoder + return self.encoder + + def get_decoder(self): + # return self.cross_decoder + return self.decoder + + @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=DualDecoderLMOutput, config_class=_CONFIG_FOR_DOC_DDT5) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + # encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], DualDecoderLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ..., + config.vocab_size - 1]`. 
All labels set to `-100` are ignored (masked), the loss is only computed for + labels in `[0, ..., config.vocab_size]` + Returns: + Examples: + ```python + >>> from transformers import T5Tokenizer, T5DualDecoderLMHeadModel + >>> tokenizer = T5Tokenizer.from_pretrained("t5-small") + >>> model = T5ForConditionalGeneration.from_pretrained("t5-small") + >>> # training + >>> input_ids = tokenizer("The walks in park", return_tensors="pt").input_ids + >>> labels = tokenizer(" cute dog the ", return_tensors="pt").input_ids + >>> outputs = model(input_ids=input_ids, labels=labels) + >>> loss = outputs.loss + >>> logits = outputs.logits + >>> # inference + >>> input_ids = tokenizer( + ... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt" + ... ).input_ids # Batch size 1 + >>> outputs = model.generate(input_ids) + >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True)) + >>> # studies have shown that owning a dog is good for you. + ```""" + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + if past_key_values is not None: + self_decoder_past_key_value = past_key_values[0] + cross_decoder_past_key_value = past_key_values[1] + else: + self_decoder_past_key_value, cross_decoder_past_key_value = None, None + + if labels is not None and input_ids is None and inputs_embeds is None: + # get decoder inputs from shifting lm labels to the right + input_ids = self._shift_right(labels) + + # self attention decoder + # self_decoder_outputs = self.self_decoder( + self_decoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values=self_decoder_past_key_value, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = self_decoder_outputs[0] + position_bias = self_decoder_outputs[-1] + + # get encoder hidden states + # encoder_hidden_states = None if encoder_outputs is None else encoder_outputs[0] + # encoder_attention_mask = None + + # Set device for model parallelism + if self.model_parallel: + # torch.cuda.set_device(self.cross_decoder.first_device) + # hidden_states = hidden_states.to(self.cross_decoder.first_device) + # if attention_mask is not None: + # attention_mask = attention_mask.to(self.cross_decoder.first_device) + torch.cuda.set_device(self.decoder.first_device) + hidden_states = hidden_states.to(self.decoder.first_device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.decoder.first_device) + + + # cross attention decoder + # cross_decoder_outputs = self.cross_decoder( + cross_decoder_outputs = self.decoder( + attention_mask=attention_mask, + inputs_embeds=hidden_states, + position_bias=position_bias, + past_key_values=cross_decoder_past_key_value, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + 
return_dict=return_dict, + ) + + sequence_output = cross_decoder_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + # torch.cuda.set_device(self.self_decoder.first_device) + # self.lm_head = self.lm_head.to(self.self_decoder.first_device) + torch.cuda.set_device(self.encoder.first_device) + self.lm_head = self.lm_head.to(self.encoder.first_device) + sequence_output = sequence_output.to(self.lm_head.weight.device) + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.model_dim**-0.5) + + lm_logits = self.lm_head(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) + # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 + + if self_decoder_outputs.past_key_values is None or cross_decoder_outputs.past_key_values is None: + past_key_values = None + else: + past_key_values=(self_decoder_outputs.past_key_values, cross_decoder_outputs.past_key_values) + + if not return_dict: + output = (lm_logits, past_key_values) + cross_decoder_outputs[2:] + (self_decoder_outputs[0],) + self_decoder_outputs[2:] + return ((loss,) + output) if loss is not None else output + + return DualDecoderLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=past_key_values, + cross_decoder_hidden_states=cross_decoder_outputs.hidden_states, + cross_decoder_attentions=cross_decoder_outputs.attentions, + cross_attentions=cross_decoder_outputs.cross_attentions, + self_decoder_last_hidden_state=self_decoder_outputs.last_hidden_state, + self_decoder_hidden_states=self_decoder_outputs.hidden_states, + self_decoder_attentions=self_decoder_outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past=None, + attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + # encoder_outputs=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + **kwargs + ): + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "input_ids": input_ids, + "past_key_values": past, + # "encoder_outputs": encoder_outputs, + "encoder_hidden_states": encoder_hidden_states, + "encoder_attention_mask": encoder_attention_mask, + "attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, + } + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return self._shift_right(labels) + + def _reorder_cache(self, past, beam_idx): + if past is None: + logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") + return past + + return (self._reorder_cache_single(past[0], beam_idx), self._reorder_cache_single(past[1], beam_idx)) + + def _reorder_cache_single(self, past, beam_idx): + # if decoder past is not included in output + # speedy decoding is disabled and no need to reorder + if past is None: + logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") + return past + + reordered_decoder_past = () + for layer_past_states in past: + # get the correct batch idx from 
layer past batch dim + # batch dim of `past` is at 2nd position + reordered_layer_past_states = () + for layer_past_state in layer_past_states: + # need to set correct `past` for each of the four key / value states + reordered_layer_past_states = reordered_layer_past_states + ( + layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)), + ) + + assert reordered_layer_past_states[0].shape == layer_past_states[0].shape + assert len(reordered_layer_past_states) == len(layer_past_states) + + reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) + return reordered_decoder_past + + + +@add_start_docstrings("""T5 Dual Decoder with a `language modeling` head on top.""", T5_START_DOCSTRING) +class T5DualDecoderDoubleHeadsModel(T5PreTrainedModel): + + def __init__(self, config: T5Config, add_pooling_layer: bool = True): + config.is_encoder_decoder = False + config.is_decoder = True + super().__init__(config) + self.model_dim = config.d_model + + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + self_decoder_config = copy.deepcopy(config) + self_decoder_config.is_decoder = True + self_decoder_config.is_encoder_decoder = False + self_decoder_config.add_cross_attention = False + # self.self_decoder = T5DecoderStack(self_decoder_config, self.shared) + self.encoder = T5DecoderStack(self_decoder_config, self.shared) + + cross_decoder_config = copy.deepcopy(config) + cross_decoder_config.is_decoder = True + cross_decoder_config.is_encoder_decoder = False + cross_decoder_config.add_cross_attention = True + cross_decoder_config.num_layers = config.num_decoder_layers + # self.cross_decoder = T5DecoderStack(cross_decoder_config, has_relative_attention_bias=False) + self.decoder = T5DecoderStack(cross_decoder_config, self.shared, has_relative_attention_bias=False) + + self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) + sequence_summary_config = copy.deepcopy(config) + sequence_summary_config.summary_type = "cls_index" + self.ss_head = SequenceSummary(config) + + # Initialize weights and apply final processing + self.post_init() + + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + self.device_map = ( + # get_device_map(len(self.self_decoder.block), range(torch.cuda.device_count())) + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + # assert_device_map(self.device_map, len(self.self_decoder.block)) + assert_device_map(self.device_map, len(self.encoder.block)) + # self.self_decoder.parallelize(self.device_map) + # self.cross_decoder.parallelize(self.device_map) + # self.lm_head = self.lm_head.to(self.cross_decoder.first_device) + # self.ss_head = self.ss_head.to(self.cross_decoder.first_device) + self.encoder.parallelize(self.device_map) + self.decoder.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.decoder.first_device) + self.ss_head = self.ss_head.to(self.decoder.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + # self.self_decoder.deparallelize() + # self.cross_decoder.deparallelize() + # self.self_decoder = self.self_decoder.to("cpu") + # self.cross_decoder = self.cross_decoder.to("cpu") + self.encoder.deparallelize() + self.decoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.decoder = self.decoder.to("cpu") + self.lm_head = 
self.lm_head.to("cpu") + self.ss_head = self.ss_head.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + # self.self_decoder.set_input_embeddings(new_embeddings) + # self.cross_decoder.set_input_embeddings(new_embeddings) + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_output_embeddings(self): + return self.lm_head + + def get_encoder(self): + # return self.self_decoder + return self.encoder + + def get_decoder(self): + # return self.cross_decoder + return self.decoder + + @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=DualDecoderDoubleHeadsOutput, config_class=_CONFIG_FOR_DOC_DDT5) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + # encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], DualDecoderDoubleHeadsOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ..., + config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for + labels in `[0, ..., config.vocab_size]` + Returns: + Examples: + ```python + >>> from transformers import T5Tokenizer, T5DualDecoderDoubleHeadsModel + >>> tokenizer = T5Tokenizer.from_pretrained("veld-t5-base") + >>> model = T5DualDecoderDoubleHeadsModel.from_pretrained("veld-t5-base") + >>> # training + >>> input_ids = tokenizer("The walks in park", return_tensors="pt").input_ids + >>> labels = tokenizer(" cute dog the ", return_tensors="pt").input_ids + >>> outputs = model(input_ids=input_ids, labels=labels) + >>> loss = outputs.loss + >>> logits = outputs.logits + >>> # inference + >>> input_ids = tokenizer( + ... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt" + ... ).input_ids # Batch size 1 + >>> outputs = model.generate(input_ids) + >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True)) + >>> # studies have shown that owning a dog is good for you. 
+ ```""" + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + if past_key_values is not None: + self_decoder_past_key_value = past_key_values[0] + cross_decoder_past_key_value = past_key_values[1] + else: + self_decoder_past_key_value, cross_decoder_past_key_value = None, None + + if labels is not None and input_ids is None and inputs_embeds is None: + # get decoder inputs from shifting lm labels to the right + input_ids = self._shift_right(labels) + + # self attention decoder + # self_decoder_outputs = self.self_decoder( + self_decoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values=self_decoder_past_key_value, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = self_decoder_outputs[0] + position_bias = self_decoder_outputs[-1] + + # get encoder hidden states + # encoder_hidden_states = None if encoder_outputs is None else encoder_outputs[0] + # encoder_attention_mask = None + + # Set device for model parallelism + if self.model_parallel: + # torch.cuda.set_device(self.cross_decoder.first_device) + # hidden_states = hidden_states.to(self.cross_decoder.first_device) + # if attention_mask is not None: + # attention_mask = attention_mask.to(self.cross_decoder.first_device) + torch.cuda.set_device(self.decoder.first_device) + hidden_states = hidden_states.to(self.decoder.first_device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.decoder.first_device) + + + # cross attention decoder + # cross_decoder_outputs = self.cross_decoder( + cross_decoder_outputs = self.decoder( + attention_mask=attention_mask, + inputs_embeds=hidden_states, + position_bias=position_bias, + past_key_values=cross_decoder_past_key_value, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = cross_decoder_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + # torch.cuda.set_device(self.self_decoder.first_device) + # self.lm_head = self.lm_head.to(self.self_decoder.first_device) + torch.cuda.set_device(self.encoder.first_device) + self.lm_head = self.lm_head.to(self.encoder.first_device) + sequence_output = sequence_output.to(self.lm_head.weight.device) + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.model_dim**-0.5) + + lm_logits = self.lm_head(sequence_output) + # cls_index = None if attention_mask is None else get_last_token_index(attention_mask) + if self.config.pad_token_id is None: + cls_index = None + else: + if input_ids is not None: + cls_index = torch.ne(input_ids, 
self.config.pad_token_id).sum(-1) - 1 + else: + cls_index = None + logger.warning( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + ss_logits = self.ss_head(hidden_states, cls_index=cls_index) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) + # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 + + if self_decoder_outputs.past_key_values is None or cross_decoder_outputs.past_key_values is None: + past_key_values = None + else: + past_key_values=(self_decoder_outputs.past_key_values, cross_decoder_outputs.past_key_values) + + if not return_dict: + output = (lm_logits, ss_logits, past_key_values) + cross_decoder_outputs[2:] + (self_decoder_outputs[0],) + self_decoder_outputs[2:] + return ((loss,) + output) if loss is not None else output + + return DualDecoderDoubleHeadsOutput( + loss=loss, + logits=lm_logits, + ss_logits=ss_logits, + past_key_values=past_key_values, + cross_decoder_hidden_states=cross_decoder_outputs.hidden_states, + cross_decoder_attentions=cross_decoder_outputs.attentions, + cross_attentions=cross_decoder_outputs.cross_attentions, + self_decoder_last_hidden_state=self_decoder_outputs.last_hidden_state, + self_decoder_hidden_states=self_decoder_outputs.hidden_states, + self_decoder_attentions=self_decoder_outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past=None, + attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + # encoder_outputs=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + **kwargs + ): + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "input_ids": input_ids, + "past_key_values": past, + # "encoder_outputs": encoder_outputs, + "encoder_hidden_states": encoder_hidden_states, + "encoder_attention_mask": encoder_attention_mask, + "attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, + } + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return self._shift_right(labels) + + def _reorder_cache(self, past, beam_idx): + if past is None: + logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") + return past + + return (self._reorder_cache_single(past[0], beam_idx), self._reorder_cache_single(past[1], beam_idx)) + + def _reorder_cache_single(self, past, beam_idx): + # if decoder past is not included in output + # speedy decoding is disabled and no need to reorder + if past is None: + logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") + return past + + reordered_decoder_past = () + for layer_past_states in past: + # get the correct batch idx from layer past batch dim + # batch dim of `past` is at 2nd position + reordered_layer_past_states = () + for layer_past_state in layer_past_states: + # need to set correct `past` for each of the four key / value states + reordered_layer_past_states = reordered_layer_past_states + ( + layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)), + ) + + assert reordered_layer_past_states[0].shape == 
layer_past_states[0].shape + assert len(reordered_layer_past_states) == len(layer_past_states) + + reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) + return reordered_decoder_past + + +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_utils import PreTrainedModel +from transformers.models.vision_encoder_decoder.modeling_vision_encoder_decoder import ( + VISION_ENCODER_DECODER_START_DOCSTRING, + VISION_ENCODER_DECODER_INPUTS_DOCSTRING, +) +from transformers.models.auto.configuration_auto import AutoConfig +from transformers.models.auto.modeling_auto import AutoModel +from transformers import ViTModel, ViTConfig + +from .configuration_veld import VELDConfig + +_CONFIG_FOR_DOC_VELDT5 = "VELDConfig" + +@dataclass +class VELDDoubleHeadsOutput(ModelOutput): + """ + Base class for sequence-to-sequence language models outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. 
+ encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + c_loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + e_logits_g: torch.FloatTensor = None + e_logits_l: torch.FloatTensor = None + d_logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + +@add_start_docstrings(VISION_ENCODER_DECODER_START_DOCSTRING) +class VELDModel(PreTrainedModel): + r""" + [`VELDModel`] is a generic model class that will be instantiated as a transformer architecture with + one of the base vision model classes of the library as encoder and another one as dual decoder when created with the + :meth*~transformers.AutoModel.from_pretrained* class method for the encoder. + """ + config_class = VELDConfig + base_model_prefix = "veld" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + + def __init__( + self, + config: Optional[PretrainedConfig] = None, + encoder: Optional[PreTrainedModel] = None, + decoder: Optional[PreTrainedModel] = None, + ): + if config is None and (encoder is None or decoder is None): + raise ValueError("Either a configuration or an encoder and a decoder has to be provided.") + if config is None: + config = VELDConfig.from_encoder_decoder_configs(encoder.config, decoder.config) + else: + if not isinstance(config, self.config_class): + raise ValueError(f"Config: {config} has to be of type {self.config_class}") + + if config.decoder.cross_attention_hidden_size is not None: + if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size: + raise ValueError( + "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal" + f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for" + f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for" + " `config.encoder.hidden_size`." 
+ ) + + # initialize with config + # make sure input & output embeddings is not tied + config.tie_word_embeddings = False + super().__init__(config) + + if encoder is None: + encoder = ViTModel(config.encoder, add_pooling_layer=False) + + if decoder is None: + decoder = T5DualDecoderDoubleHeadsModel(config.decoder) + + self.encoder = encoder + self.decoder = decoder + + if self.encoder.config.to_dict() != self.config.encoder.to_dict(): + logger.warning( + f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config:" + f" {self.config.encoder}" + ) + if self.decoder.config.to_dict() != self.config.decoder.to_dict(): + logger.warning( + f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:" + f" {self.config.decoder}" + ) + + # make sure that the individual model's config refers to the shared config + # so that the updates to the config will be synced + self.encoder.config = self.config.encoder + self.decoder.config = self.config.decoder + + # encoder outputs might need to be projected to different dimension for decoder + if ( + self.encoder.config.hidden_size != self.decoder.config.hidden_size + and self.decoder.config.cross_attention_hidden_size is None + ): + self.enc_to_dec_proj = nn.Linear(self.encoder.config.hidden_size, self.decoder.config.hidden_size) + + if self.encoder.get_output_embeddings() is not None: + raise ValueError( + f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head" + ) + + + pooling_config = copy.deepcopy(self.encoder.config) + pooling_config.summary_type = "attn" + self.global_pooling = SequenceSummary(pooling_config, num_queries=self.config.num_queries_global) + self.local_pooling = SequenceSummary(pooling_config, num_queries=self.config.num_queries_local) + + + def _set_gradient_checkpointing(self, module, value=False): + # call both encoder and decoder function on gradient checkpointing + self.encoder._set_gradient_checkpointing(module, value=value) + self.decoder._set_gradient_checkpointing(module, value=value) + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def get_output_embeddings(self): + return self.decoder.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + return self.decoder.set_output_embeddings(new_embeddings) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + # At the moment fast initialization is not supported for composite models + if kwargs.get("_fast_init", False): + logger.warning( + "Fast initialization is currently not supported for VELDModel. " + "Falling back to slow initialization..." + ) + kwargs["_fast_init"] = False + return super().from_pretrained(*args, **kwargs) + + @classmethod + def from_encoder_decoder_pretrained( + cls, + encoder_pretrained_model_name_or_path: str = None, + decoder_pretrained_model_name_or_path: str = None, + *model_args, + **kwargs + ) -> PreTrainedModel: + r""" + Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model + checkpoints. + + + The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train + the model, you need to first set it back in training mode with `model.train()`. + + Params: + encoder_pretrained_model_name_or_path (`str`, *optional*): + Information necessary to initiate the image encoder. Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. 
An + example is `google/vit-base-patch16-224-in21k`. + - A path to a *directory* containing model weights saved using + [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In + this case, `from_tf` should be set to `True` and a configuration object should be provided as + `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a + PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. + + decoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`): + Information necessary to initiate the text decoder. Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a + user or organization name, like `dbmdz/bert-base-german-cased`. + - A path to a *directory* containing model weights saved using + [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In + this case, `from_tf` should be set to `True` and a configuration object should be provided as + `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a + PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. + + model_args (remaining positional arguments, *optional*): + All remaning positional arguments will be passed to the underlying model's `__init__` method. + + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., + `output_attentions=True`). + + - To update the encoder configuration, use the prefix *encoder_* for each configuration parameter. + - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter. + - To update the parent model configuration, do not use a prefix for each configuration parameter. + + Behaves differently depending on whether a `config` is provided or automatically loaded. + + Example: + + ```python + >>> from modeling_veld import VELDModel + + >>> # initialize a vit-t5 from a pretrained ViT and a pretrained T5 model. Note that the cross-attention layers will be randomly initialized + >>> model = VELDModel.from_encoder_decoder_pretrained( + ... "google/vit-base-patch16-224-in21k", "t5-base" + ... ) + >>> # saving model after fine-tuning + >>> model.save_pretrained("./vit-t5") + >>> # load fine-tuned model + >>> model = VELDModel.from_pretrained("./vit-t5") + ```""" + + kwargs_encoder = { + argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_") + } + + kwargs_decoder = { + argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") + } + + # remove encoder, decoder kwargs from kwargs + for key in kwargs_encoder.keys(): + del kwargs["encoder_" + key] + for key in kwargs_decoder.keys(): + del kwargs["decoder_" + key] + + # Load and initialize the encoder and decoder + # The distinction between encoder and decoder at the model level is made + # by the value of the flag `is_decoder` that we need to set correctly. 
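+        # Illustrative routing example (hypothetical values): `encoder_hidden_dropout_prob=0.1` reaches the
+        # ViT encoder config as `hidden_dropout_prob=0.1`, `decoder_dropout_rate=0.2` reaches the T5 decoder
+        # config as `dropout_rate=0.2`, and any unprefixed kwargs are forwarded to
+        # `VELDConfig.from_encoder_decoder_configs` at the end of this method.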
+ encoder = kwargs_encoder.pop("model", None) + if encoder is None: + if encoder_pretrained_model_name_or_path is None: + raise ValueError( + "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has " + "to be defined." + ) + + if "config" not in kwargs_encoder: + encoder_config, kwargs_encoder = ViTConfig.from_pretrained( + encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True + ) + + if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True: + logger.info( + f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model " + "from a decoder model. Cross-attention and casual mask are disabled." + ) + encoder_config.is_decoder = False + encoder_config.add_cross_attention = False + + kwargs_encoder["config"] = encoder_config + + encoder = ViTModel.from_pretrained(encoder_pretrained_model_name_or_path, add_pooling_layer=False, *model_args, **kwargs_encoder) + + decoder = kwargs_decoder.pop("model", None) + if decoder is None: + if decoder_pretrained_model_name_or_path is None: + raise ValueError( + "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has " + "to be defined." + ) + + if "config" not in kwargs_decoder: + decoder_config, kwargs_decoder = T5Config.from_pretrained( + decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True + ) + + if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False: + logger.info( + f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention" + f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if" + f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers." + ) + decoder_config.is_decoder = True + decoder_config.add_cross_attention = True + + kwargs_decoder["config"] = decoder_config + + if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False: + logger.warning( + f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. 
" + f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, " + "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` " + "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a " + "`decoder_config` to `.from_encoder_decoder_pretrained(...)`" + ) + + decoder = T5DualDecoderDoubleHeadsModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder) + + # instantiate config with corresponding kwargs + config = VELDConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs) + + # make sure input & output embeddings is not tied + config.tie_word_embeddings = False + return cls(encoder=encoder, decoder=decoder, config=config) + + @add_start_docstrings_to_model_forward(VISION_ENCODER_DECODER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC_VELDT5) + def forward( + self, + pixel_values=None, + decoder_input_ids=None, + decoder_attention_mask=None, + encoder_outputs=None, + past_key_values=None, + decoder_inputs_embeds=None, + labels=None, + return_contrastive_loss=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + logit_temperature=1.0, + label_smoothing=0.0, + **kwargs, + ): + r""" + Returns: + + Examples: + + ```python + >>> from transformers import AutoTokenizer, ViTFeatureExtractor, VELDModel + >>> import requests + >>> from PIL import Image + >>> import torch + + >>> processor = ViTFeatureExtractor.from_pretrained("KETI-AIR/veld-base") + >>> tokenizer = AutoTokenizer.from_pretrained("KETI-AIR/veld-base") + >>> model = VELDModel.from_pretrained("KETI-AIR/veld-base") + + >>> # load image from the IAM dataset + >>> url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw).convert("RGB") + + >>> # training + >>> pixel_values = processor(image, return_tensors="pt").pixel_values + >>> text = "hello world" + >>> labels = tokenizer(text, return_tensors="pt").input_ids + >>> outputs = model(pixel_values=pixel_values, labels=labels) + >>> loss = outputs.loss + + >>> # inference (generation) + >>> generated_ids = model.generate(pixel_values, max_new_tokens=20) + >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")} + + kwargs_decoder = { + argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") + } + + if encoder_outputs is None and pixel_values is not None: + # if pixel_values is None: + # raise ValueError("You have to specify pixel_values") + + encoder_outputs = self.encoder( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs_encoder, + ) + elif isinstance(encoder_outputs, tuple): + encoder_outputs = BaseModelOutput(*encoder_outputs) + + encoder_hidden_states = None if encoder_outputs is None else encoder_outputs[0] + pooler_output_local = None if encoder_outputs is None else self.local_pooling(encoder_hidden_states) + pooler_output_global = None if encoder_outputs is None else self.global_pooling(pooler_output_local).squeeze(1) + + # optionally project encoder_hidden_states + if ( + self.encoder.config.hidden_size != 
self.decoder.config.hidden_size + and self.decoder.config.cross_attention_hidden_size is None + and pooler_output_local is not None + ): + pooler_output_local = self.enc_to_dec_proj(pooler_output_local) + + + # else: + encoder_attention_mask = None + + if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None): + decoder_input_ids = self.decoder.prepare_decoder_input_ids_from_labels(labels) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=pooler_output_local, + encoder_attention_mask=encoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + use_cache=use_cache, + past_key_values=past_key_values, + return_dict=return_dict, + **kwargs_decoder, + ) + + # Compute loss independent from decoder (as some shift the logits inside them) + loss = None + if labels is not None: + logits = decoder_outputs.logits if return_dict else decoder_outputs[0] + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1)) + + c_loss = None + if return_contrastive_loss is not None and encoder_outputs is not None: + decoder_logits = decoder_outputs.ss_logits if return_dict else decoder_outputs[0] + encoder_logits = pooler_output_global + loss_fct = CrossEntropyLoss(label_smoothing=label_smoothing) + + if ( + self.encoder.config.hidden_size != self.decoder.config.hidden_size + and self.decoder.config.cross_attention_hidden_size is None + ): + encoder_logits = self.enc_to_dec_proj(encoder_logits) + + + encoder_logits = nn.functional.normalize(encoder_logits) + decoder_logits = nn.functional.normalize(decoder_logits) + + batch_size = encoder_logits.size(0) + scores = torch.mm(decoder_logits, encoder_logits.t()) + target = torch.arange(batch_size).to(decoder_logits.device) + + c_loss = loss_fct(scores/logit_temperature, target) + loss_fct(scores.t()/logit_temperature, target) + + + if decoder_outputs.self_decoder_hidden_states is not None and decoder_outputs.cross_decoder_hidden_states is not None: + decoder_hidden_states = decoder_outputs.self_decoder_hidden_states + decoder_outputs.cross_decoder_hidden_states + else: + decoder_hidden_states = None + + if decoder_outputs.self_decoder_attentions is not None and decoder_outputs.cross_decoder_attentions is not None: + decoder_attentions = decoder_outputs.self_decoder_attentions + decoder_outputs.cross_decoder_attentions + else: + decoder_attentions = None + + if not return_dict: + outputs = ( + decoder_outputs.logits, + pooler_output_global, + pooler_output_local, + decoder_outputs.ss_logits, + decoder_outputs.past_key_values, + decoder_hidden_states, + decoder_attentions, + decoder_outputs.cross_attentions, + None if encoder_outputs is None else encoder_outputs.last_hidden_state, + None if encoder_outputs is None else encoder_outputs.hidden_states, + None if encoder_outputs is None else encoder_outputs.attentions, + ) + if c_loss is not None: + outputs = (c_loss,) + outputs + if loss is not None: + return (loss,) + outputs + else: + return outputs + + return VELDDoubleHeadsOutput( + loss=loss, + c_loss=c_loss, + logits=decoder_outputs.logits, + e_logits_g=pooler_output_global, + e_logits_l=pooler_output_local, + d_logits=decoder_outputs.ss_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_hidden_states, + decoder_attentions=decoder_attentions, + 
cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=None if encoder_outputs is None else encoder_outputs.last_hidden_state, + encoder_hidden_states=None if encoder_outputs is None else encoder_outputs.hidden_states, + encoder_attentions=None if encoder_outputs is None else encoder_outputs.attentions, + ) + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return self.decoder.prepare_decoder_input_ids_from_labels(labels) + + def prepare_inputs_for_generation( + self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs + ): + decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past) + decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None + input_dict = { + "attention_mask": attention_mask, + "decoder_attention_mask": decoder_attention_mask, + "decoder_input_ids": decoder_inputs["input_ids"], + "encoder_outputs": encoder_outputs, + "past_key_values": decoder_inputs["past_key_values"], + "use_cache": use_cache, + } + return input_dict + + def resize_token_embeddings(self, *args, **kwargs): + raise NotImplementedError( + "Resizing the embedding layers via the VisionEncoderDecoderModel directly is not supported.Please use the" + " respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))" + ) + + def _reorder_cache(self, past, beam_idx): + # apply decoder cache reordering here + return self.decoder._reorder_cache(past, beam_idx) + + + + +if __name__ == "__main__": + from transformers import AutoTokenizer, ViTFeatureExtractor + from PIL import Image + + VISION_PRETRAINED_MODEL = "google/vit-base-patch16-384" + LANGUAGE_PRETRAINED_MODEL = "KETI-AIR/ke-t5-base" + + test_inputs = [ + "To update the encoder configuration, use the prefix *encoder_* for each configuration parameter.", + "To update the parent model configuration,", + ] + tokenizer = AutoTokenizer.from_pretrained(LANGUAGE_PRETRAINED_MODEL) + inps = tokenizer(test_inputs, padding=True, truncation="longest_first", return_tensors="pt") + + + # config = T5Config.from_pretrained(LANGUAGE_PRETRAINED_MODEL) + # model = T5DualDecoderDoubleHeadsModel.from_pretrained(LANGUAGE_PRETRAINED_MODEL) + + # print(inps.input_ids.size()) + # outputs = model( + # input_ids=inps.input_ids, + # attention_mask=inps.attention_mask, + # ) + + # with torch.no_grad(): + # print(model.generate(inps.input_ids, num_beams=4, max_new_tokens=20)) + + + + + + + feature_extractor = ViTFeatureExtractor.from_pretrained(VISION_PRETRAINED_MODEL) + images = [Image.open("images/sample.jpg"), Image.open("images/sample2.jpg")] + pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values + + model = VELDModel.from_encoder_decoder_pretrained( + VISION_PRETRAINED_MODEL, + LANGUAGE_PRETRAINED_MODEL + ) + + outputs = model( + labels=inps.input_ids, + return_contrastive_loss=True, + decoder_attention_mask=inps.attention_mask + ) + print(outputs.loss) + print(outputs.c_loss) + + outputs = model( + pixel_values=pixel_values, + labels=inps.input_ids, + return_contrastive_loss=True, + decoder_attention_mask=inps.attention_mask) + print(outputs.loss) + print(outputs.c_loss) + + # print(outputs) + + # outputs = model.generate( + # pixel_values=pixel_values, + # decoder_input_ids=inps.input_ids, + # decoder_attention_mask=inps.attention_mask, + # num_beams=4, + # max_new_tokens=20 + # ) + # print(outputs)
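+    # A minimal similarity-scoring sketch: the global image head and the decoder's sequence-summary head
+    # can be reused as a CLIP-style image-text scorer, mirroring the contrastive loss computed in
+    # `VELDModel.forward`. It assumes the encoder and decoder hidden sizes match (true for vit-base +
+    # ke-t5-base), so no `enc_to_dec_proj` projection is applied here.
+    with torch.no_grad():
+        img_emb = nn.functional.normalize(outputs.e_logits_g)  # [batch, d_model] pooled global image embeddings
+        txt_emb = nn.functional.normalize(outputs.d_logits)    # [batch, d_model] sequence-summary text embeddings
+        similarity = torch.mm(txt_emb, img_emb.t())            # [batch_text, batch_image] cosine similarities
+    print(similarity)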