veld-base / modeling_veld.py
kimsan0622's picture
Upload model
eb96320
# Copyright 2022 san kim
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import math
import os
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union, Callable
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.utils.checkpoint import checkpoint
try:
from torch.nn import Identity
except ImportError:
# Older PyTorch compatibility
class Identity(nn.Module):
r"""A placeholder identity operator that is argument-insensitive."""
def __init__(self, *args, **kwargs):
super().__init__()
def forward(self, input):
return input
from transformers.models.t5.modeling_t5 import (
T5LayerSelfAttention,
T5LayerCrossAttention,
T5LayerFF,
T5PreTrainedModel,
T5LayerNorm,
PARALLELIZE_DOCSTRING,
DEPARALLELIZE_DOCSTRING,
__HEAD_MASK_WARNING_MSG,
T5_START_DOCSTRING,
T5_INPUTS_DOCSTRING
)
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
Seq2SeqLMOutput,
BaseModelOutput
)
from transformers.utils import (
DUMMY_INPUTS,
DUMMY_MASK,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_torch_fx_proxy,
logging,
replace_return_docstrings,
ModelOutput,
)
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
from transformers import T5Config
from transformers.configuration_utils import PretrainedConfig
from transformers.activations import get_activation
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC_DDT5 = "T5Config"
def get_last_token_index(mask):
# attention masks: [batch_size, seq]
batch_size, seq_length = mask.shape[:2]
incr = torch.arange(seq_length, device=mask.device, requires_grad=False)
incr_m = torch.einsum("i,ji->ji", incr, mask)
return torch.argmax(incr_m, dim=1)
# modified from huggingface transformers lib (add attn pooling)
class SequenceSummary(nn.Module):
r"""
Compute a single vector summary of a sequence hidden states.
Args:
config ([`PretrainedConfig`]):
The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
config class of your model for the default values it uses):
- **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:
- `"last"` -- Take the last token hidden state (like XLNet)
- `"first"` -- Take the first token hidden state (like Bert)
- `"mean"` -- Take the mean of all tokens hidden states
- `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
- `"attn"` -- Not implemented now, use multi-head attention
- **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
- **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
(otherwise to `config.hidden_size`).
- **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
another string or `None` will add no activation.
- **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
- **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation.
"""
def __init__(self, config: PretrainedConfig, num_queries=1):
super().__init__()
self.summary_type = getattr(config, "summary_type", "last")
if self.summary_type == "attn":
# We should use a standard multi-head attention module with absolute positional embedding for that.
# Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
# We can probably just use the multi-head attention module of PyTorch >=1.1.0
self.queries = nn.Parameter(torch.empty(num_queries, config.hidden_size))
nn.init.kaiming_uniform_(self.queries, a=math.sqrt(5))
self.MultiheadAttention = nn.MultiheadAttention(
config.hidden_size,
config.num_attention_heads,
batch_first=True
)
layer_norm_eps = getattr(config, "layer_norm_eps", 1e-6)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=layer_norm_eps)
self.summary = Identity()
if hasattr(config, "summary_use_proj") and config.summary_use_proj:
if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
num_classes = config.num_labels
else:
num_classes = config.hidden_size
self.summary = nn.Linear(config.hidden_size, num_classes)
activation_string = getattr(config, "summary_activation", None)
self.activation: Callable = get_activation(activation_string) if activation_string else Identity()
self.first_dropout = Identity()
if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
self.first_dropout = nn.Dropout(config.summary_first_dropout)
self.last_dropout = Identity()
if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
self.last_dropout = nn.Dropout(config.summary_last_dropout)
def forward(
self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
) -> torch.FloatTensor:
"""
Compute a single vector summary of a sequence hidden states.
Args:
hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
The hidden states of the last layer.
cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.
Returns:
`torch.FloatTensor`: The summary of the sequence hidden states.
"""
if self.summary_type == "last":
output = hidden_states[:, -1]
elif self.summary_type == "first":
output = hidden_states[:, 0]
elif self.summary_type == "mean":
output = hidden_states.mean(dim=1)
elif self.summary_type == "cls_index":
if cls_index is None:
cls_index = torch.full_like(
hidden_states[..., :1, :],
hidden_states.shape[-2] - 1,
dtype=torch.long,
)
else:
cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
# shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size)
elif self.summary_type == "attn":
batch_size = hidden_states.size(0)
queries = self.queries.repeat(batch_size, 1, 1)
output = self.MultiheadAttention(queries, hidden_states, hidden_states, need_weights=False)[0]
output = self.LayerNorm(output)
output = self.first_dropout(output)
output = self.summary(output)
output = self.activation(output)
output = self.last_dropout(output)
return output
# add_cross_attention
class T5DecoderBlock(nn.Module):
def __init__(self, config, has_relative_attention_bias=False):
super().__init__()
self.is_decoder = config.is_decoder
self.has_cross_attention = config.add_cross_attention
self.layer = nn.ModuleList()
self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
if self.has_cross_attention:
self.layer.append(T5LayerCrossAttention(config))
self.layer.append(T5LayerFF(config))
def forward(
self,
hidden_states,
attention_mask=None,
position_bias=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
encoder_decoder_position_bias=None,
layer_head_mask=None,
cross_attn_layer_head_mask=None,
past_key_value=None,
use_cache=False,
output_attentions=False,
return_dict=True,
):
if past_key_value is not None:
if not self.is_decoder:
logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.")
expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
if len(past_key_value) != expected_num_past_key_values:
raise ValueError(
f"There should be {expected_num_past_key_values} past states. "
f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
f"Got {len(past_key_value)} past key / value states"
)
self_attn_past_key_value = past_key_value[:2]
cross_attn_past_key_value = past_key_value[2:]
else:
self_attn_past_key_value, cross_attn_past_key_value = None, None
self_attention_outputs = self.layer[0](
hidden_states,
attention_mask=attention_mask,
position_bias=position_bias,
layer_head_mask=layer_head_mask,
past_key_value=self_attn_past_key_value,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states, present_key_value_state = self_attention_outputs[:2]
attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights
# clamp inf values to enable fp16 training
if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
do_cross_attention = self.has_cross_attention and encoder_hidden_states is not None
if do_cross_attention:
# the actual query length is unknown for cross attention
# if using past key value states. Need to inject it here
if present_key_value_state is not None:
query_length = present_key_value_state[0].shape[2]
else:
query_length = None
cross_attention_outputs = self.layer[1](
hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
position_bias=encoder_decoder_position_bias,
layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value,
query_length=query_length,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = cross_attention_outputs[0]
# clamp inf values to enable fp16 training
if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
# Combine self attn and cross attn key value states
if present_key_value_state is not None:
present_key_value_state = present_key_value_state + cross_attention_outputs[1]
# Keep cross-attention outputs and relative position weights
attention_outputs = attention_outputs + cross_attention_outputs[2:]
# Apply Feed Forward layer
hidden_states = self.layer[-1](hidden_states)
# clamp inf values to enable fp16 training
if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
outputs = (hidden_states,)
if use_cache:
outputs = outputs + (present_key_value_state,) + attention_outputs
else:
outputs = outputs + attention_outputs
return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
@dataclass
class BaseModelOutputWithPastAndCrossAttentionsAndPositionBias(ModelOutput):
"""
Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding) plus position bias.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
hidden_size)` is output.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
`config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
position_bias (`torch.FloatTensor`, *optional*, returned when the model is self-attention decoder):
position_bias is created in the first layer of the self-attention decoder, and it passes through all the layers including layers of the cross-attention decoder.
`torch.FloatTensor` of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
"""
last_hidden_state: torch.FloatTensor = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
position_bias: Optional[torch.FloatTensor] = None
class T5DecoderStack(T5PreTrainedModel):
def __init__(self, config, embed_tokens=None, has_relative_attention_bias=True):
super().__init__(config)
self.embed_tokens = embed_tokens
self.is_decoder = config.is_decoder
self.has_cross_attention = config.add_cross_attention
self.block = nn.ModuleList(
[T5DecoderBlock(config, has_relative_attention_bias=bool(i == 0) and has_relative_attention_bias) for i in range(config.num_layers)]
)
self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate)
# Initialize weights and apply final processing
self.post_init()
# Model parallel
self.model_parallel = False
self.device_map = None
self.gradient_checkpointing = False
@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
# Check validity of device_map
self.device_map = (
get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map
)
assert_device_map(self.device_map, len(self.block))
self.model_parallel = True
self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
self.last_device = "cuda:" + str(max(self.device_map.keys()))
# Load onto devices
for k, v in self.device_map.items():
for layer in v:
cuda_device = "cuda:" + str(k)
self.block[layer] = self.block[layer].to(cuda_device)
# Set embed_tokens to first layer
self.embed_tokens = self.embed_tokens.to(self.first_device) if self.embed_tokens is not None else self.embed_tokens
# Set final layer norm to last device
self.final_layer_norm = self.final_layer_norm.to(self.last_device)
@add_start_docstrings(PARALLELIZE_DOCSTRING)
def deparallelize(self):
self.model_parallel = False
self.device_map = None
self.first_device = "cpu"
self.last_device = "cpu"
for i in range(len(self.block)):
self.block[i] = self.block[i].to("cpu")
self.embed_tokens = self.embed_tokens.to("cpu") if self.embed_tokens is not None else self.embed_tokens
self.final_layer_norm = self.final_layer_norm.to("cpu")
torch.cuda.empty_cache()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, new_embeddings):
self.embed_tokens = new_embeddings
def forward(
self,
input_ids=None,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
inputs_embeds=None,
position_bias=None,
encoder_decoder_position_bias=None,
head_mask=None,
cross_attn_head_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
# Model parallel
if self.model_parallel:
torch.cuda.set_device(self.first_device)
self.embed_tokens = self.embed_tokens.to(self.first_device) if self.embed_tokens is not None else self.embed_tokens
use_cache = use_cache if use_cache is not None else self.config.use_cache
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
err_msg_prefix = "decoder_" if self.is_decoder else ""
raise ValueError(
f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
)
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
err_msg_prefix = "decoder_" if self.is_decoder else ""
raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
if inputs_embeds is None:
assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
inputs_embeds = self.embed_tokens(input_ids)
batch_size, seq_length = input_shape
# required mask seq length can be calculated via length of past
mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
if use_cache is True:
assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder"
if attention_mask is None:
attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
if self.has_cross_attention and encoder_attention_mask is None and encoder_hidden_states is not None:
encoder_seq_length = encoder_hidden_states.shape[1]
encoder_attention_mask = torch.ones(
batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long
)
# initialize past_key_values with `None` if past does not exist
if past_key_values is None:
past_key_values = [None] * len(self.block)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.has_cross_attention and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None
# Prepare head mask if needed
head_mask = self.get_head_mask(head_mask, self.config.num_layers)
cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
present_key_value_states = () if use_cache else None
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
all_cross_attentions = () if (output_attentions and self.has_cross_attention) else None
hidden_states = self.dropout(inputs_embeds)
for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
layer_head_mask = head_mask[i]
cross_attn_layer_head_mask = cross_attn_head_mask[i]
# Model parallel
if self.model_parallel:
torch.cuda.set_device(hidden_states.device)
# Ensure that attention_mask is always on the same device as hidden_states
if attention_mask is not None:
attention_mask = attention_mask.to(hidden_states.device)
if position_bias is not None:
position_bias = position_bias.to(hidden_states.device)
if encoder_hidden_states is not None:
encoder_hidden_states = encoder_hidden_states.to(hidden_states.device)
if encoder_extended_attention_mask is not None:
encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
if encoder_decoder_position_bias is not None:
encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
if layer_head_mask is not None:
layer_head_mask = layer_head_mask.to(hidden_states.device)
if cross_attn_layer_head_mask is not None:
cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
return tuple(module(*inputs, use_cache, output_attentions))
return custom_forward
layer_outputs = checkpoint(
create_custom_forward(layer_module),
hidden_states,
extended_attention_mask,
position_bias,
encoder_hidden_states,
encoder_extended_attention_mask,
encoder_decoder_position_bias,
layer_head_mask,
cross_attn_layer_head_mask,
None, # past_key_value is always None with gradient checkpointing
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask=extended_attention_mask,
position_bias=position_bias,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
encoder_decoder_position_bias=encoder_decoder_position_bias,
layer_head_mask=layer_head_mask,
cross_attn_layer_head_mask=cross_attn_layer_head_mask,
past_key_value=past_key_value,
use_cache=use_cache,
output_attentions=output_attentions,
)
# layer_outputs is a tuple with:
# hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
if use_cache is False:
layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
hidden_states, present_key_value_state = layer_outputs[:2]
# We share the position biases between the layers - the first layer store them
# layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
# (cross-attention position bias), (cross-attention weights)
position_bias = layer_outputs[2]
if self.has_cross_attention and encoder_hidden_states is not None:
encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
# append next layer key value states
if use_cache:
present_key_value_states = present_key_value_states + (present_key_value_state,)
if output_attentions:
all_attentions = all_attentions + (layer_outputs[3],)
if self.has_cross_attention:
all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
# Model Parallel: If it's the last layer for that device, put things on the next device
if self.model_parallel:
for k, v in self.device_map.items():
if i == v[-1] and "cuda:" + str(k) != self.last_device:
hidden_states = hidden_states.to("cuda:" + str(k + 1))
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.dropout(hidden_states)
# Add last layer
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
outputs = tuple(
v
for v in [
hidden_states,
present_key_value_states,
all_hidden_states,
all_attentions,
all_cross_attentions,
]
if v is not None
)
outputs = outputs + (position_bias,)
return outputs
return BaseModelOutputWithPastAndCrossAttentionsAndPositionBias(
last_hidden_state=hidden_states,
past_key_values=present_key_value_states,
hidden_states=all_hidden_states,
attentions=all_attentions,
cross_attentions=all_cross_attentions,
position_bias=position_bias
)
@dataclass
class DualDecoderModelOutput(ModelOutput):
"""
Base class for model dual decoder's outputs that also contains : pre-computed hidden states that can speed up sequential
decoding.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the decoder of the model.
If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
hidden_size)` is output.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
cross_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the cross-attention decoder at the output of each layer.
cross_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights of the cross-attention decoder, after the attention softmax, used to compute the weighted average in the
cross-attention heads.
cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
self_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the self-attention decoder of the model.
self_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the self-attention decoder at the output of each layer plus the optional initial embedding outputs.
self_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights of the self-attention decoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
"""
last_hidden_state: torch.FloatTensor = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
cross_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
cross_decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
self_decoder_last_hidden_state: Optional[torch.FloatTensor] = None
self_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
self_decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class DualDecoderLMOutput(ModelOutput):
"""
Base class for sequence-to-sequence language models outputs.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss.
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
cross_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the cross-attention decoder at the output of each layer.
cross_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights of the cross-attention decoder, after the attention softmax, used to compute the weighted average in the
cross-attention heads.
cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
self_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the self-attention decoder of the model.
self_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the self-attention decoder at the output of each layer plus the optional initial embedding outputs.
self_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights of the self-attention decoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
cross_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
cross_decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
self_decoder_last_hidden_state: Optional[torch.FloatTensor] = None
self_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
self_decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class DualDecoderDoubleHeadsOutput(ModelOutput):
"""
Base class for sequence-to-sequence language models outputs.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss.
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
ss_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Global representaion of the self-attention decoder. The last token of sequence is used to calculate this representation.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
cross_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the cross-attention decoder at the output of each layer.
cross_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights of the cross-attention decoder, after the attention softmax, used to compute the weighted average in the
cross-attention heads.
cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
self_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the self-attention decoder of the model.
self_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the self-attention decoder at the output of each layer plus the optional initial embedding outputs.
self_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights of the self-attention decoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
ss_logits: torch.FloatTensor = None
past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
cross_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
cross_decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
self_decoder_last_hidden_state: Optional[torch.FloatTensor] = None
self_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
self_decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
@add_start_docstrings("""T5 Dual Decoder with a `language modeling` head on top.""", T5_START_DOCSTRING)
class T5DualDecoderLMHeadModel(T5PreTrainedModel):
def __init__(self, config: T5Config, add_pooling_layer: bool = True):
config.is_encoder_decoder = False
config.is_decoder = True
super().__init__(config)
self.model_dim = config.d_model
self.shared = nn.Embedding(config.vocab_size, config.d_model)
self_decoder_config = copy.deepcopy(config)
self_decoder_config.is_decoder = True
self_decoder_config.is_encoder_decoder = False
self_decoder_config.add_cross_attention = False
# self.self_decoder = T5DecoderStack(self_decoder_config, self.shared)
self.encoder = T5DecoderStack(self_decoder_config, self.shared)
cross_decoder_config = copy.deepcopy(config)
cross_decoder_config.is_decoder = True
cross_decoder_config.is_encoder_decoder = False
cross_decoder_config.add_cross_attention = True
cross_decoder_config.num_layers = config.num_decoder_layers
# self.cross_decoder = T5DecoderStack(cross_decoder_config, has_relative_attention_bias=False)
self.decoder = T5DecoderStack(cross_decoder_config, self.shared, has_relative_attention_bias=False)
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
# Model parallel
self.model_parallel = False
self.device_map = None
@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
self.device_map = (
# get_device_map(len(self.self_decoder.block), range(torch.cuda.device_count()))
get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
if device_map is None
else device_map
)
# assert_device_map(self.device_map, len(self.self_decoder.block))
assert_device_map(self.device_map, len(self.encoder.block))
# self.self_decoder.parallelize(self.device_map)
# self.cross_decoder.parallelize(self.device_map)
# self.lm_head = self.lm_head.to(self.cross_decoder.first_device)
self.encoder.parallelize(self.device_map)
self.decoder.parallelize(self.device_map)
self.lm_head = self.lm_head.to(self.decoder.first_device)
self.model_parallel = True
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
def deparallelize(self):
# self.self_decoder.deparallelize()
# self.cross_decoder.deparallelize()
# self.self_decoder = self.self_decoder.to("cpu")
# self.cross_decoder = self.cross_decoder.to("cpu")
self.encoder.deparallelize()
self.decoder.deparallelize()
self.encoder = self.encoder.to("cpu")
self.decoder = self.decoder.to("cpu")
self.lm_head = self.lm_head.to("cpu")
self.model_parallel = False
self.device_map = None
torch.cuda.empty_cache()
def get_input_embeddings(self):
return self.shared
def set_input_embeddings(self, new_embeddings):
self.shared = new_embeddings
# self.self_decoder.set_input_embeddings(new_embeddings)
# self.cross_decoder.set_input_embeddings(new_embeddings)
self.encoder.set_input_embeddings(new_embeddings)
self.decoder.set_input_embeddings(new_embeddings)
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def get_output_embeddings(self):
return self.lm_head
def get_encoder(self):
# return self.self_decoder
return self.encoder
def get_decoder(self):
# return self.cross_decoder
return self.decoder
@add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=DualDecoderLMOutput, config_class=_CONFIG_FOR_DOC_DDT5)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
decoder_head_mask: Optional[torch.FloatTensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
# encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.FloatTensor], DualDecoderLMOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
labels in `[0, ..., config.vocab_size]`
Returns:
Examples:
```python
>>> from transformers import T5Tokenizer, T5DualDecoderLMHeadModel
>>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
>>> model = T5ForConditionalGeneration.from_pretrained("t5-small")
>>> # training
>>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
>>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
>>> outputs = model(input_ids=input_ids, labels=labels)
>>> loss = outputs.loss
>>> logits = outputs.logits
>>> # inference
>>> input_ids = tokenizer(
... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
... ).input_ids # Batch size 1
>>> outputs = model.generate(input_ids)
>>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
>>> # studies have shown that owning a dog is good for you.
```"""
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
if head_mask is not None and decoder_head_mask is None:
if self.config.num_layers == self.config.num_decoder_layers:
warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
decoder_head_mask = head_mask
if past_key_values is not None:
self_decoder_past_key_value = past_key_values[0]
cross_decoder_past_key_value = past_key_values[1]
else:
self_decoder_past_key_value, cross_decoder_past_key_value = None, None
if labels is not None and input_ids is None and inputs_embeds is None:
# get decoder inputs from shifting lm labels to the right
input_ids = self._shift_right(labels)
# self attention decoder
# self_decoder_outputs = self.self_decoder(
self_decoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
past_key_values=self_decoder_past_key_value,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = self_decoder_outputs[0]
position_bias = self_decoder_outputs[-1]
# get encoder hidden states
# encoder_hidden_states = None if encoder_outputs is None else encoder_outputs[0]
# encoder_attention_mask = None
# Set device for model parallelism
if self.model_parallel:
# torch.cuda.set_device(self.cross_decoder.first_device)
# hidden_states = hidden_states.to(self.cross_decoder.first_device)
# if attention_mask is not None:
# attention_mask = attention_mask.to(self.cross_decoder.first_device)
torch.cuda.set_device(self.decoder.first_device)
hidden_states = hidden_states.to(self.decoder.first_device)
if attention_mask is not None:
attention_mask = attention_mask.to(self.decoder.first_device)
# cross attention decoder
# cross_decoder_outputs = self.cross_decoder(
cross_decoder_outputs = self.decoder(
attention_mask=attention_mask,
inputs_embeds=hidden_states,
position_bias=position_bias,
past_key_values=cross_decoder_past_key_value,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = cross_decoder_outputs[0]
# Set device for model parallelism
if self.model_parallel:
# torch.cuda.set_device(self.self_decoder.first_device)
# self.lm_head = self.lm_head.to(self.self_decoder.first_device)
torch.cuda.set_device(self.encoder.first_device)
self.lm_head = self.lm_head.to(self.encoder.first_device)
sequence_output = sequence_output.to(self.lm_head.weight.device)
if self.config.tie_word_embeddings:
# Rescale output before projecting on vocab
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
sequence_output = sequence_output * (self.model_dim**-0.5)
lm_logits = self.lm_head(sequence_output)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
# TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
if self_decoder_outputs.past_key_values is None or cross_decoder_outputs.past_key_values is None:
past_key_values = None
else:
past_key_values=(self_decoder_outputs.past_key_values, cross_decoder_outputs.past_key_values)
if not return_dict:
output = (lm_logits, past_key_values) + cross_decoder_outputs[2:] + (self_decoder_outputs[0],) + self_decoder_outputs[2:]
return ((loss,) + output) if loss is not None else output
return DualDecoderLMOutput(
loss=loss,
logits=lm_logits,
past_key_values=past_key_values,
cross_decoder_hidden_states=cross_decoder_outputs.hidden_states,
cross_decoder_attentions=cross_decoder_outputs.attentions,
cross_attentions=cross_decoder_outputs.cross_attentions,
self_decoder_last_hidden_state=self_decoder_outputs.last_hidden_state,
self_decoder_hidden_states=self_decoder_outputs.hidden_states,
self_decoder_attentions=self_decoder_outputs.attentions,
)
def prepare_inputs_for_generation(
self,
input_ids,
past=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
# encoder_outputs=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
**kwargs
):
# cut decoder_input_ids if past is used
if past is not None:
input_ids = input_ids[:, -1:]
return {
"input_ids": input_ids,
"past_key_values": past,
# "encoder_outputs": encoder_outputs,
"encoder_hidden_states": encoder_hidden_states,
"encoder_attention_mask": encoder_attention_mask,
"attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache,
}
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return self._shift_right(labels)
def _reorder_cache(self, past, beam_idx):
if past is None:
logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
return past
return (self._reorder_cache_single(past[0], beam_idx), self._reorder_cache_single(past[1], beam_idx))
def _reorder_cache_single(self, past, beam_idx):
# if decoder past is not included in output
# speedy decoding is disabled and no need to reorder
if past is None:
logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
return past
reordered_decoder_past = ()
for layer_past_states in past:
# get the correct batch idx from layer past batch dim
# batch dim of `past` is at 2nd position
reordered_layer_past_states = ()
for layer_past_state in layer_past_states:
# need to set correct `past` for each of the four key / value states
reordered_layer_past_states = reordered_layer_past_states + (
layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
)
assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
assert len(reordered_layer_past_states) == len(layer_past_states)
reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
return reordered_decoder_past
@add_start_docstrings("""T5 Dual Decoder with a `language modeling` head on top.""", T5_START_DOCSTRING)
class T5DualDecoderDoubleHeadsModel(T5PreTrainedModel):
def __init__(self, config: T5Config, add_pooling_layer: bool = True):
config.is_encoder_decoder = False
config.is_decoder = True
super().__init__(config)
self.model_dim = config.d_model
self.shared = nn.Embedding(config.vocab_size, config.d_model)
self_decoder_config = copy.deepcopy(config)
self_decoder_config.is_decoder = True
self_decoder_config.is_encoder_decoder = False
self_decoder_config.add_cross_attention = False
# self.self_decoder = T5DecoderStack(self_decoder_config, self.shared)
self.encoder = T5DecoderStack(self_decoder_config, self.shared)
cross_decoder_config = copy.deepcopy(config)
cross_decoder_config.is_decoder = True
cross_decoder_config.is_encoder_decoder = False
cross_decoder_config.add_cross_attention = True
cross_decoder_config.num_layers = config.num_decoder_layers
# self.cross_decoder = T5DecoderStack(cross_decoder_config, has_relative_attention_bias=False)
self.decoder = T5DecoderStack(cross_decoder_config, self.shared, has_relative_attention_bias=False)
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
sequence_summary_config = copy.deepcopy(config)
sequence_summary_config.summary_type = "cls_index"
self.ss_head = SequenceSummary(config)
# Initialize weights and apply final processing
self.post_init()
# Model parallel
self.model_parallel = False
self.device_map = None
@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
self.device_map = (
# get_device_map(len(self.self_decoder.block), range(torch.cuda.device_count()))
get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
if device_map is None
else device_map
)
# assert_device_map(self.device_map, len(self.self_decoder.block))
assert_device_map(self.device_map, len(self.encoder.block))
# self.self_decoder.parallelize(self.device_map)
# self.cross_decoder.parallelize(self.device_map)
# self.lm_head = self.lm_head.to(self.cross_decoder.first_device)
# self.ss_head = self.ss_head.to(self.cross_decoder.first_device)
self.encoder.parallelize(self.device_map)
self.decoder.parallelize(self.device_map)
self.lm_head = self.lm_head.to(self.decoder.first_device)
self.ss_head = self.ss_head.to(self.decoder.first_device)
self.model_parallel = True
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
def deparallelize(self):
# self.self_decoder.deparallelize()
# self.cross_decoder.deparallelize()
# self.self_decoder = self.self_decoder.to("cpu")
# self.cross_decoder = self.cross_decoder.to("cpu")
self.encoder.deparallelize()
self.decoder.deparallelize()
self.encoder = self.encoder.to("cpu")
self.decoder = self.decoder.to("cpu")
self.lm_head = self.lm_head.to("cpu")
self.ss_head = self.ss_head.to("cpu")
self.model_parallel = False
self.device_map = None
torch.cuda.empty_cache()
def get_input_embeddings(self):
return self.shared
def set_input_embeddings(self, new_embeddings):
self.shared = new_embeddings
# self.self_decoder.set_input_embeddings(new_embeddings)
# self.cross_decoder.set_input_embeddings(new_embeddings)
self.encoder.set_input_embeddings(new_embeddings)
self.decoder.set_input_embeddings(new_embeddings)
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def get_output_embeddings(self):
return self.lm_head
def get_encoder(self):
# return self.self_decoder
return self.encoder
def get_decoder(self):
# return self.cross_decoder
return self.decoder
@add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=DualDecoderDoubleHeadsOutput, config_class=_CONFIG_FOR_DOC_DDT5)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
decoder_head_mask: Optional[torch.FloatTensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
# encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.FloatTensor], DualDecoderDoubleHeadsOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
labels in `[0, ..., config.vocab_size]`
Returns:
Examples:
```python
>>> from transformers import T5Tokenizer, T5DualDecoderDoubleHeadsModel
>>> tokenizer = T5Tokenizer.from_pretrained("veld-t5-base")
>>> model = T5DualDecoderDoubleHeadsModel.from_pretrained("veld-t5-base")
>>> # training
>>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
>>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
>>> outputs = model(input_ids=input_ids, labels=labels)
>>> loss = outputs.loss
>>> logits = outputs.logits
>>> # inference
>>> input_ids = tokenizer(
... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
... ).input_ids # Batch size 1
>>> outputs = model.generate(input_ids)
>>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
>>> # studies have shown that owning a dog is good for you.
```"""
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
if head_mask is not None and decoder_head_mask is None:
if self.config.num_layers == self.config.num_decoder_layers:
warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
decoder_head_mask = head_mask
if past_key_values is not None:
self_decoder_past_key_value = past_key_values[0]
cross_decoder_past_key_value = past_key_values[1]
else:
self_decoder_past_key_value, cross_decoder_past_key_value = None, None
if labels is not None and input_ids is None and inputs_embeds is None:
# get decoder inputs from shifting lm labels to the right
input_ids = self._shift_right(labels)
# self attention decoder
# self_decoder_outputs = self.self_decoder(
self_decoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
past_key_values=self_decoder_past_key_value,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = self_decoder_outputs[0]
position_bias = self_decoder_outputs[-1]
# get encoder hidden states
# encoder_hidden_states = None if encoder_outputs is None else encoder_outputs[0]
# encoder_attention_mask = None
# Set device for model parallelism
if self.model_parallel:
# torch.cuda.set_device(self.cross_decoder.first_device)
# hidden_states = hidden_states.to(self.cross_decoder.first_device)
# if attention_mask is not None:
# attention_mask = attention_mask.to(self.cross_decoder.first_device)
torch.cuda.set_device(self.decoder.first_device)
hidden_states = hidden_states.to(self.decoder.first_device)
if attention_mask is not None:
attention_mask = attention_mask.to(self.decoder.first_device)
# cross attention decoder
# cross_decoder_outputs = self.cross_decoder(
cross_decoder_outputs = self.decoder(
attention_mask=attention_mask,
inputs_embeds=hidden_states,
position_bias=position_bias,
past_key_values=cross_decoder_past_key_value,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = cross_decoder_outputs[0]
# Set device for model parallelism
if self.model_parallel:
# torch.cuda.set_device(self.self_decoder.first_device)
# self.lm_head = self.lm_head.to(self.self_decoder.first_device)
torch.cuda.set_device(self.encoder.first_device)
self.lm_head = self.lm_head.to(self.encoder.first_device)
sequence_output = sequence_output.to(self.lm_head.weight.device)
if self.config.tie_word_embeddings:
# Rescale output before projecting on vocab
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
sequence_output = sequence_output * (self.model_dim**-0.5)
lm_logits = self.lm_head(sequence_output)
# cls_index = None if attention_mask is None else get_last_token_index(attention_mask)
if self.config.pad_token_id is None:
cls_index = None
else:
if input_ids is not None:
cls_index = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
else:
cls_index = None
logger.warning(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)
ss_logits = self.ss_head(hidden_states, cls_index=cls_index)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
# TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
if self_decoder_outputs.past_key_values is None or cross_decoder_outputs.past_key_values is None:
past_key_values = None
else:
past_key_values=(self_decoder_outputs.past_key_values, cross_decoder_outputs.past_key_values)
if not return_dict:
output = (lm_logits, ss_logits, past_key_values) + cross_decoder_outputs[2:] + (self_decoder_outputs[0],) + self_decoder_outputs[2:]
return ((loss,) + output) if loss is not None else output
return DualDecoderDoubleHeadsOutput(
loss=loss,
logits=lm_logits,
ss_logits=ss_logits,
past_key_values=past_key_values,
cross_decoder_hidden_states=cross_decoder_outputs.hidden_states,
cross_decoder_attentions=cross_decoder_outputs.attentions,
cross_attentions=cross_decoder_outputs.cross_attentions,
self_decoder_last_hidden_state=self_decoder_outputs.last_hidden_state,
self_decoder_hidden_states=self_decoder_outputs.hidden_states,
self_decoder_attentions=self_decoder_outputs.attentions,
)
def prepare_inputs_for_generation(
self,
input_ids,
past=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
# encoder_outputs=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
**kwargs
):
# cut decoder_input_ids if past is used
if past is not None:
input_ids = input_ids[:, -1:]
return {
"input_ids": input_ids,
"past_key_values": past,
# "encoder_outputs": encoder_outputs,
"encoder_hidden_states": encoder_hidden_states,
"encoder_attention_mask": encoder_attention_mask,
"attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache,
}
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return self._shift_right(labels)
def _reorder_cache(self, past, beam_idx):
if past is None:
logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
return past
return (self._reorder_cache_single(past[0], beam_idx), self._reorder_cache_single(past[1], beam_idx))
def _reorder_cache_single(self, past, beam_idx):
# if decoder past is not included in output
# speedy decoding is disabled and no need to reorder
if past is None:
logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
return past
reordered_decoder_past = ()
for layer_past_states in past:
# get the correct batch idx from layer past batch dim
# batch dim of `past` is at 2nd position
reordered_layer_past_states = ()
for layer_past_state in layer_past_states:
# need to set correct `past` for each of the four key / value states
reordered_layer_past_states = reordered_layer_past_states + (
layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
)
assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
assert len(reordered_layer_past_states) == len(layer_past_states)
reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
return reordered_decoder_past
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel
from transformers.models.vision_encoder_decoder.modeling_vision_encoder_decoder import (
VISION_ENCODER_DECODER_START_DOCSTRING,
VISION_ENCODER_DECODER_INPUTS_DOCSTRING,
)
from transformers.models.auto.configuration_auto import AutoConfig
from transformers.models.auto.modeling_auto import AutoModel
from transformers import ViTModel, ViTConfig
from .configuration_veld import VELDConfig
_CONFIG_FOR_DOC_VELDT5 = "VELDConfig"
@dataclass
class VELDDoubleHeadsOutput(ModelOutput):
"""
Base class for sequence-to-sequence language models outputs.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss.
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder of the model.
encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
"""
loss: Optional[torch.FloatTensor] = None
c_loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
e_logits_g: torch.FloatTensor = None
e_logits_l: torch.FloatTensor = None
d_logits: torch.FloatTensor = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
@add_start_docstrings(VISION_ENCODER_DECODER_START_DOCSTRING)
class VELDModel(PreTrainedModel):
r"""
[`VELDModel`] is a generic model class that will be instantiated as a transformer architecture with
one of the base vision model classes of the library as encoder and another one as dual decoder when created with the
:meth*~transformers.AutoModel.from_pretrained* class method for the encoder.
"""
config_class = VELDConfig
base_model_prefix = "veld"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
def __init__(
self,
config: Optional[PretrainedConfig] = None,
encoder: Optional[PreTrainedModel] = None,
decoder: Optional[PreTrainedModel] = None,
):
if config is None and (encoder is None or decoder is None):
raise ValueError("Either a configuration or an encoder and a decoder has to be provided.")
if config is None:
config = VELDConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
else:
if not isinstance(config, self.config_class):
raise ValueError(f"Config: {config} has to be of type {self.config_class}")
if config.decoder.cross_attention_hidden_size is not None:
if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
raise ValueError(
"If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal"
f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for"
" `config.encoder.hidden_size`."
)
# initialize with config
# make sure input & output embeddings is not tied
config.tie_word_embeddings = False
super().__init__(config)
if encoder is None:
encoder = ViTModel(config.encoder, add_pooling_layer=False)
if decoder is None:
decoder = T5DualDecoderDoubleHeadsModel(config.decoder)
self.encoder = encoder
self.decoder = decoder
if self.encoder.config.to_dict() != self.config.encoder.to_dict():
logger.warning(
f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config:"
f" {self.config.encoder}"
)
if self.decoder.config.to_dict() != self.config.decoder.to_dict():
logger.warning(
f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:"
f" {self.config.decoder}"
)
# make sure that the individual model's config refers to the shared config
# so that the updates to the config will be synced
self.encoder.config = self.config.encoder
self.decoder.config = self.config.decoder
# encoder outputs might need to be projected to different dimension for decoder
if (
self.encoder.config.hidden_size != self.decoder.config.hidden_size
and self.decoder.config.cross_attention_hidden_size is None
):
self.enc_to_dec_proj = nn.Linear(self.encoder.config.hidden_size, self.decoder.config.hidden_size)
if self.encoder.get_output_embeddings() is not None:
raise ValueError(
f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head"
)
pooling_config = copy.deepcopy(self.encoder.config)
pooling_config.summary_type = "attn"
self.global_pooling = SequenceSummary(pooling_config, num_queries=self.config.num_queries_global)
self.local_pooling = SequenceSummary(pooling_config, num_queries=self.config.num_queries_local)
def _set_gradient_checkpointing(self, module, value=False):
# call both encoder and decoder function on gradient checkpointing
self.encoder._set_gradient_checkpointing(module, value=value)
self.decoder._set_gradient_checkpointing(module, value=value)
def get_encoder(self):
return self.encoder
def get_decoder(self):
return self.decoder
def get_output_embeddings(self):
return self.decoder.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
return self.decoder.set_output_embeddings(new_embeddings)
@classmethod
def from_pretrained(cls, *args, **kwargs):
# At the moment fast initialization is not supported for composite models
if kwargs.get("_fast_init", False):
logger.warning(
"Fast initialization is currently not supported for VELDModel. "
"Falling back to slow initialization..."
)
kwargs["_fast_init"] = False
return super().from_pretrained(*args, **kwargs)
@classmethod
def from_encoder_decoder_pretrained(
cls,
encoder_pretrained_model_name_or_path: str = None,
decoder_pretrained_model_name_or_path: str = None,
*model_args,
**kwargs
) -> PreTrainedModel:
r"""
Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model
checkpoints.
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
the model, you need to first set it back in training mode with `model.train()`.
Params:
encoder_pretrained_model_name_or_path (`str`, *optional*):
Information necessary to initiate the image encoder. Can be either:
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. An
example is `google/vit-base-patch16-224-in21k`.
- A path to a *directory* containing model weights saved using
[`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
- A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
this case, `from_tf` should be set to `True` and a configuration object should be provided as
`config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
decoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`):
Information necessary to initiate the text decoder. Can be either:
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
user or organization name, like `dbmdz/bert-base-german-cased`.
- A path to a *directory* containing model weights saved using
[`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
- A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
this case, `from_tf` should be set to `True` and a configuration object should be provided as
`config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
model_args (remaining positional arguments, *optional*):
All remaning positional arguments will be passed to the underlying model's `__init__` method.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
`output_attentions=True`).
- To update the encoder configuration, use the prefix *encoder_* for each configuration parameter.
- To update the decoder configuration, use the prefix *decoder_* for each configuration parameter.
- To update the parent model configuration, do not use a prefix for each configuration parameter.
Behaves differently depending on whether a `config` is provided or automatically loaded.
Example:
```python
>>> from modeling_veld import VELDModel
>>> # initialize a vit-t5 from a pretrained ViT and a pretrained T5 model. Note that the cross-attention layers will be randomly initialized
>>> model = VELDModel.from_encoder_decoder_pretrained(
... "google/vit-base-patch16-224-in21k", "t5-base"
... )
>>> # saving model after fine-tuning
>>> model.save_pretrained("./vit-t5")
>>> # load fine-tuned model
>>> model = VELDModel.from_pretrained("./vit-t5")
```"""
kwargs_encoder = {
argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
}
kwargs_decoder = {
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
}
# remove encoder, decoder kwargs from kwargs
for key in kwargs_encoder.keys():
del kwargs["encoder_" + key]
for key in kwargs_decoder.keys():
del kwargs["decoder_" + key]
# Load and initialize the encoder and decoder
# The distinction between encoder and decoder at the model level is made
# by the value of the flag `is_decoder` that we need to set correctly.
encoder = kwargs_encoder.pop("model", None)
if encoder is None:
if encoder_pretrained_model_name_or_path is None:
raise ValueError(
"If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
"to be defined."
)
if "config" not in kwargs_encoder:
encoder_config, kwargs_encoder = ViTConfig.from_pretrained(
encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True
)
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
logger.info(
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
"from a decoder model. Cross-attention and casual mask are disabled."
)
encoder_config.is_decoder = False
encoder_config.add_cross_attention = False
kwargs_encoder["config"] = encoder_config
encoder = ViTModel.from_pretrained(encoder_pretrained_model_name_or_path, add_pooling_layer=False, *model_args, **kwargs_encoder)
decoder = kwargs_decoder.pop("model", None)
if decoder is None:
if decoder_pretrained_model_name_or_path is None:
raise ValueError(
"If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
"to be defined."
)
if "config" not in kwargs_decoder:
decoder_config, kwargs_decoder = T5Config.from_pretrained(
decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
)
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
logger.info(
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention"
f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if"
f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
)
decoder_config.is_decoder = True
decoder_config.add_cross_attention = True
kwargs_decoder["config"] = decoder_config
if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
logger.warning(
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
"make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
"passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
"`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
)
decoder = T5DualDecoderDoubleHeadsModel.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
# instantiate config with corresponding kwargs
config = VELDConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
# make sure input & output embeddings is not tied
config.tie_word_embeddings = False
return cls(encoder=encoder, decoder=decoder, config=config)
@add_start_docstrings_to_model_forward(VISION_ENCODER_DECODER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC_VELDT5)
def forward(
self,
pixel_values=None,
decoder_input_ids=None,
decoder_attention_mask=None,
encoder_outputs=None,
past_key_values=None,
decoder_inputs_embeds=None,
labels=None,
return_contrastive_loss=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
logit_temperature=1.0,
label_smoothing=0.0,
**kwargs,
):
r"""
Returns:
Examples:
```python
>>> from transformers import AutoTokenizer, ViTFeatureExtractor, VELDModel
>>> import requests
>>> from PIL import Image
>>> import torch
>>> processor = ViTFeatureExtractor.from_pretrained("KETI-AIR/veld-base")
>>> tokenizer = AutoTokenizer.from_pretrained("KETI-AIR/veld-base")
>>> model = VELDModel.from_pretrained("KETI-AIR/veld-base")
>>> # load image from the IAM dataset
>>> url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
>>> # training
>>> pixel_values = processor(image, return_tensors="pt").pixel_values
>>> text = "hello world"
>>> labels = tokenizer(text, return_tensors="pt").input_ids
>>> outputs = model(pixel_values=pixel_values, labels=labels)
>>> loss = outputs.loss
>>> # inference (generation)
>>> generated_ids = model.generate(pixel_values, max_new_tokens=20)
>>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
kwargs_decoder = {
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
}
if encoder_outputs is None and pixel_values is not None:
# if pixel_values is None:
# raise ValueError("You have to specify pixel_values")
encoder_outputs = self.encoder(
pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs_encoder,
)
elif isinstance(encoder_outputs, tuple):
encoder_outputs = BaseModelOutput(*encoder_outputs)
encoder_hidden_states = None if encoder_outputs is None else encoder_outputs[0]
pooler_output_local = None if encoder_outputs is None else self.local_pooling(encoder_hidden_states)
pooler_output_global = None if encoder_outputs is None or return_contrastive_loss is None else self.global_pooling(pooler_output_local).squeeze(1)
# optionally project encoder_hidden_states
if (
self.encoder.config.hidden_size != self.decoder.config.hidden_size
and self.decoder.config.cross_attention_hidden_size is None
and pooler_output_local is not None
):
pooler_output_local = self.enc_to_dec_proj(pooler_output_local)
# else:
encoder_attention_mask = None
if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
decoder_input_ids = self.decoder.prepare_decoder_input_ids_from_labels(labels)
# Decode
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
encoder_hidden_states=pooler_output_local,
encoder_attention_mask=encoder_attention_mask,
inputs_embeds=decoder_inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_cache=use_cache,
past_key_values=past_key_values,
return_dict=return_dict,
**kwargs_decoder,
)
# Compute loss independent from decoder (as some shift the logits inside them)
loss = None
if labels is not None:
logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1))
c_loss = None
if return_contrastive_loss is not None and encoder_outputs is not None:
decoder_logits = decoder_outputs.ss_logits if return_dict else decoder_outputs[0]
encoder_logits = pooler_output_global
loss_fct = CrossEntropyLoss(label_smoothing=label_smoothing)
if (
self.encoder.config.hidden_size != self.decoder.config.hidden_size
and self.decoder.config.cross_attention_hidden_size is None
):
encoder_logits = self.enc_to_dec_proj(encoder_logits)
encoder_logits = nn.functional.normalize(encoder_logits)
decoder_logits = nn.functional.normalize(decoder_logits)
batch_size = encoder_logits.size(0)
scores = torch.mm(decoder_logits, encoder_logits.t())
target = torch.arange(batch_size).to(decoder_logits.device)
c_loss = loss_fct(scores/logit_temperature, target) + loss_fct(scores.t()/logit_temperature, target)
if decoder_outputs.self_decoder_hidden_states is not None and decoder_outputs.cross_decoder_hidden_states is not None:
decoder_hidden_states = decoder_outputs.self_decoder_hidden_states + decoder_outputs.cross_decoder_hidden_states
else:
decoder_hidden_states = None
if decoder_outputs.self_decoder_attentions is not None and decoder_outputs.cross_decoder_attentions is not None:
decoder_attentions = decoder_outputs.self_decoder_attentions + decoder_outputs.cross_decoder_attentions
else:
decoder_attentions = None
if not return_dict:
outputs = (
decoder_outputs.logits,
pooler_output_global,
pooler_output_local,
decoder_outputs.ss_logits,
decoder_outputs.past_key_values,
decoder_hidden_states,
decoder_attentions,
decoder_outputs.cross_attentions,
None if encoder_outputs is None else encoder_outputs.last_hidden_state,
None if encoder_outputs is None else encoder_outputs.hidden_states,
None if encoder_outputs is None else encoder_outputs.attentions,
)
if c_loss is not None:
outputs = (c_loss,) + outputs
if loss is not None:
return (loss,) + outputs
else:
return outputs
return VELDDoubleHeadsOutput(
loss=loss,
c_loss=c_loss,
logits=decoder_outputs.logits,
e_logits_g=pooler_output_global,
e_logits_l=pooler_output_local,
d_logits=decoder_outputs.ss_logits,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_hidden_states,
decoder_attentions=decoder_attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=None if encoder_outputs is None else encoder_outputs.last_hidden_state,
encoder_hidden_states=None if encoder_outputs is None else encoder_outputs.hidden_states,
encoder_attentions=None if encoder_outputs is None else encoder_outputs.attentions,
)
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return self.decoder.prepare_decoder_input_ids_from_labels(labels)
def prepare_inputs_for_generation(
self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
):
decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
input_dict = {
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
"decoder_input_ids": decoder_inputs["input_ids"],
"encoder_outputs": encoder_outputs,
"past_key_values": decoder_inputs["past_key_values"],
"use_cache": use_cache,
}
return input_dict
def resize_token_embeddings(self, *args, **kwargs):
raise NotImplementedError(
"Resizing the embedding layers via the VisionEncoderDecoderModel directly is not supported.Please use the"
" respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))"
)
def _reorder_cache(self, past, beam_idx):
# apply decoder cache reordering here
return self.decoder._reorder_cache(past, beam_idx)
if __name__ == "__main__":
from transformers import AutoTokenizer, ViTFeatureExtractor
from PIL import Image
VISION_PRETRAINED_MODEL = "google/vit-base-patch16-384"
LANGUAGE_PRETRAINED_MODEL = "KETI-AIR/ke-t5-base"
test_inputs = [
"To update the encoder configuration, use the prefix *encoder_* for each configuration parameter.",
"To update the parent model configuration,",
]
tokenizer = AutoTokenizer.from_pretrained(LANGUAGE_PRETRAINED_MODEL)
inps = tokenizer(test_inputs, padding=True, truncation="longest_first", return_tensors="pt")
# config = T5Config.from_pretrained(LANGUAGE_PRETRAINED_MODEL)
# model = T5DualDecoderDoubleHeadsModel.from_pretrained(LANGUAGE_PRETRAINED_MODEL)
# print(inps.input_ids.size())
# outputs = model(
# input_ids=inps.input_ids,
# attention_mask=inps.attention_mask,
# )
# with torch.no_grad():
# print(model.generate(inps.input_ids, num_beams=4, max_new_tokens=20))
feature_extractor = ViTFeatureExtractor.from_pretrained(VISION_PRETRAINED_MODEL)
images = [Image.open("images/sample.jpg"), Image.open("images/sample2.jpg")]
pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
model = VELDModel.from_encoder_decoder_pretrained(
VISION_PRETRAINED_MODEL,
LANGUAGE_PRETRAINED_MODEL
)
# outputs = model(
# labels=inps.input_ids,
# return_contrastive_loss=True,
# decoder_attention_mask=inps.attention_mask
# )
# print(outputs.loss)
# print(outputs.c_loss)
# outputs = model(
# pixel_values=pixel_values,
# labels=inps.input_ids,
# return_contrastive_loss=True,
# decoder_attention_mask=inps.attention_mask)
# print(outputs.loss)
# print(outputs.c_loss)
# outputs = model(
# pixel_values=pixel_values,
# labels=inps.input_ids,
# decoder_attention_mask=inps.attention_mask)
# print(outputs.loss)
# print(outputs.c_loss)
# print(outputs)
outputs = model.generate(
pixel_values=pixel_values,
decoder_input_ids=inps.input_ids,
decoder_attention_mask=inps.attention_mask,
num_beams=4,
max_new_tokens=20
)
# print(outputs)