from dataclasses import dataclass
from typing import Optional, Tuple

import torch

from transformers.file_utils import ModelOutput


@dataclass
class TranceptionCausalLMOutputWithCrossAttentions(ModelOutput):
""" |
|
Class for Tranception causal language model (or autoregressive) outputs. |
|
Args: |
|
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): |
|
Language modeling loss (for next-token prediction). |
|
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): |
|
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). |
|
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): |
|
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of |
|
shape `(batch_size, sequence_length, hidden_size)`. |
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs. |
|
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): |
|
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, |
|
sequence_length)`. |
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention |
|
heads. |
        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Cross-attention weights after the attention softmax, used to compute the weighted average in the
            cross-attention heads.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `torch.FloatTensor` tuples of length `config.n_layers`, with each tuple containing the cached key
            and value states of the self-attention and the cross-attention layers if the model is used in an
            encoder-decoder setting. Only relevant if `config.is_decoder = True`.

            Contains pre-computed hidden-states (keys and values in the attention blocks) that can be used (see the
            `past_key_values` input) to speed up sequential decoding.
        fused_shift_log_probas (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`, *optional*, returned when `config.retrieval_aggregation_mode` is not None):
            Log probabilities for each residue position after aggregating the autoregressive logits and the retrieval
            logits.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    fused_shift_log_probas: Optional[torch.FloatTensor] = None
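

# Illustrative usage sketch (not part of the Tranception model itself): a minimal
# example of how a forward pass might package its tensors into this output class and
# how a caller can read them back. The shapes and `dummy_*` tensors below are
# hypothetical placeholders, not values produced by the real model.
if __name__ == "__main__":
    batch_size, seq_len, vocab_size = 2, 5, 25

    dummy_logits = torch.randn(batch_size, seq_len, vocab_size)
    # fused_shift_log_probas would only be populated when retrieval aggregation is enabled.
    dummy_fused = torch.log_softmax(torch.randn(batch_size, seq_len, vocab_size), dim=-1)

    output = TranceptionCausalLMOutputWithCrossAttentions(
        loss=torch.tensor(1.23),
        logits=dummy_logits,
        fused_shift_log_probas=dummy_fused,
    )

    # ModelOutput supports attribute and key access; fields left as None
    # (e.g. attentions) are omitted from keys()/to_tuple().
    assert output.logits is output["logits"]
    print(output.loss, output.logits.shape, list(output.keys()))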