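# t5-vae-python / model/outputs.py
# Commit 0b69648 (Fraser): "add transformer-vae code" -- 5.06 kB.
# NOTE: this header was captured from a web file viewer (raw / history / blame
# links, "No virus" scan badge) and is kept only as provenance comments.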
from typing import Optional, Tuple
import flax
import jaxlib.xla_extension as jax_xla
from transformers.file_utils import ModelOutput
@flax.struct.dataclass
class TransformerVaeOutput(ModelOutput):
    """
    Base class for a Transformer-VAE's outputs.

    Transformer-VAE specific Args:
        latent_codes (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, n_latent_tokens, latent_token_size)`, `optional`):
            Latent codes representing encoded sequences.
        remade_encoder_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, n_tokens, model_dim)`, `optional`):
            Reconstructed encoder hidden states representing sequences.

    Standard sequence-to-sequence Args:
        logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`, `optional`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (:obj:`tuple(tuple(jax_xla.DeviceArray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
            Tuple of :obj:`tuple(jax_xla.DeviceArray)` of length :obj:`config.n_layers`, with each tuple having 2
            tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional
            tensors of shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding.
        last_hidden_state (:obj:`jax_xla.DeviceArray`, `optional`):
            Last model hidden state.
        decoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
            layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
        decoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
            sequence_length, sequence_length)`.

            Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
            self-attention heads.
        cross_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
            sequence_length, sequence_length)`.

            Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
            weighted average in the cross-attention heads.
        encoder_last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder of the model.
        encoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
            layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
        encoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
            sequence_length, sequence_length)`.

            Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
            self-attention heads.
    """

    # NOTE(review): keep the field order stable. transformers' ``ModelOutput`` also
    # supports tuple-style / positional access, so reordering these fields would
    # presumably break callers that index the output by position.
    # Every field defaults to ``None``, so each annotation is ``Optional``.
    logits: Optional[jax_xla.DeviceArray] = None
    latent_codes: Optional[jax_xla.DeviceArray] = None
    remade_encoder_hidden_state: Optional[jax_xla.DeviceArray] = None
    # Standard seq2seq outputs.
    # ``Tuple[X, ...]`` marks variable-length tuples (one entry per layer / per cached tensor).
    past_key_values: Optional[Tuple[Tuple[jax_xla.DeviceArray, ...], ...]] = None
    decoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray, ...]] = None
    decoder_attentions: Optional[Tuple[jax_xla.DeviceArray, ...]] = None
    cross_attentions: Optional[Tuple[jax_xla.DeviceArray, ...]] = None
    last_hidden_state: Optional[jax_xla.DeviceArray] = None
    encoder_last_hidden_state: Optional[jax_xla.DeviceArray] = None
    encoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray, ...]] = None
    encoder_attentions: Optional[Tuple[jax_xla.DeviceArray, ...]] = None