# Spaces: Runtime error
# (page status header captured together with this file; not part of the module)
from typing import Optional, Tuple

import flax
import jaxlib.xla_extension as jax_xla
from transformers.file_utils import ModelOutput
# ``ModelOutput`` subclasses are meant to be dataclasses: ``ModelOutput.__post_init__`` walks
# ``dataclasses.fields`` to populate the dict view, and that hook only runs on a dataclass.
# ``flax.struct.dataclass`` (frozen dataclass registered as a JAX pytree, so the output can be
# returned from jitted/transformed functions) is what the ``flax`` import is for. Without the
# decorator, keyword arguments would only land in the underlying dict and attribute access such
# as ``output.logits`` would see the class-level ``None`` default.
@flax.struct.dataclass
class TransformerVaeOutput(ModelOutput):
    """
    Base class for a Transformer-VAE's outputs.

    Args:
        latent_codes (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, n_latent_tokens, latent_token_size)`):
            Latent codes representing encoded sequences.
        remade_encoder_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, n_tokens, model_dim)`):
            Reconstructed encoder hidden states representing sequences.

    Standard Seq2Seq Args:
        logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (:obj:`tuple(tuple(jax_xla.DeviceArray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
            Tuple of :obj:`tuple(jax_xla.DeviceArray)` of length :obj:`config.n_layers`, with each tuple having 2
            tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional
            tensors of shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding.
        last_hidden_state (:obj:`jax_xla.DeviceArray`, `optional`):
            Last model hidden state.
        decoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
            layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
        decoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
            sequence_length, sequence_length)`.

            Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
            self-attention heads.
        cross_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
            sequence_length, encoder_sequence_length)`.

            Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
            weighted average in the cross-attention heads.
        encoder_last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder of the model.
        encoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
            layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
        encoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
            sequence_length, sequence_length)`.

            Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
            self-attention heads.
    """

    # Annotations are string literals so they are not evaluated when the class is created:
    # newer jaxlib releases may no longer define ``xla_extension.DeviceArray``, and an eager
    # reference would then fail at import time. Dataclass field discovery only needs each
    # annotation to be present, not resolvable.
    # Field order is part of the interface (ModelOutput supports positional/tuple access),
    # so it is kept exactly as before; ``logits`` stays first.
    logits: "jax_xla.DeviceArray" = None
    # Transformer-VAE specific outputs.
    latent_codes: "jax_xla.DeviceArray" = None
    remade_encoder_hidden_state: "jax_xla.DeviceArray" = None
    # Standard seq2seq outputs.
    past_key_values: "Optional[Tuple[Tuple[jax_xla.DeviceArray]]]" = None
    decoder_hidden_states: "Optional[Tuple[jax_xla.DeviceArray]]" = None
    decoder_attentions: "Optional[Tuple[jax_xla.DeviceArray]]" = None
    cross_attentions: "Optional[Tuple[jax_xla.DeviceArray]]" = None
    last_hidden_state: "Optional[jax_xla.DeviceArray]" = None
    encoder_last_hidden_state: "Optional[jax_xla.DeviceArray]" = None
    encoder_hidden_states: "Optional[Tuple[jax_xla.DeviceArray]]" = None
    encoder_attentions: "Optional[Tuple[jax_xla.DeviceArray]]" = None