Transformers documentation

Trajectory Transformer

Hugging Face's logo
Join the Hugging Face community

and get access to the augmented documentation experience

to get started

Trajectory Transformer


The Trajectory Transformer model was proposed in Offline Reinforcement Learning as One Big Sequence Modeling Problem by Michael Janner, Qiyang Li, Sergey Levine.

The abstract from the paper is the following:

Reinforcement learning (RL) is typically concerned with estimating stationary policies or single-step models, leveraging the Markov property to factorize problems in time. However, we can also view RL as a generic sequence modeling problem, with the goal being to produce a sequence of actions that leads to a sequence of high rewards. Viewed in this way, it is tempting to consider whether high-capacity sequence prediction models that work well in other domains, such as natural-language processing, can also provide effective solutions to the RL problem. To this end, we explore how RL can be tackled with the tools of sequence modeling, using a Transformer architecture to model distributions over trajectories and repurposing beam search as a planning algorithm. Framing RL as sequence modeling problem simplifies a range of design decisions, allowing us to dispense with many of the components common in offline RL algorithms. We demonstrate the flexibility of this approach across long-horizon dynamics prediction, imitation learning, goal-conditioned RL, and offline RL. Further, we show that this approach can be combined with existing model-free algorithms to yield a state-of-the-art planner in sparse-reward, long-horizon tasks.


This Transformer is used for deep reinforcement learning. To use it, you need to create sequences from actions, states and rewards from all previous timesteps. This model will treat all these elements together as one big sequence (a trajectory).

This model was contributed by CarlCochet. The original code can be found here.


class transformers.TrajectoryTransformerConfig

< >

( vocab_size = 100 batch_size = 256 action_weight = 5 reward_weight = 1 value_weight = 1 block_size = 249 action_dim = 6 observation_dim = 17 transition_dim = 25 n_layer = 4 n_head = 4 n_embd = 128 embd_pdrop = 0.1 attn_pdrop = 0.1 resid_pdrop = 0.1 learning_rate = 0.0006 max_position_embeddings = 512 type_vocab_size = 2 initializer_range = 0.02 layer_norm_eps = 1e-12 kaiming_initializer_range = 1 use_cache = True is_encoder_decoder = False pad_token_id = 1 bos_token_id = 50256 eos_token_id = 50256 **kwargs )


  • vocab_size (int, optional, defaults to 100) — Vocabulary size of the TrajectoryTransformer model. Defines the number of different tokens that can be represented by the trajectories passed when calling TrajectoryTransformerModel
  • batch_size (int, optional, defaults to 256) — Size of the batch of trajectories passed to the model.
  • action_weight (int, optional, defaults to 5) — Weight of the action in the loss function
  • reward_weight (int, optional, defaults to 1) — Weight of the reward in the loss function
  • value_weight (int, optional, defaults to 1) — Weight of the value in the loss function
  • block_size (int, optional, defaults to 249) — Size of the blocks in the trajectory transformer.
  • action_dim (int, optional, defaults to 6) — Dimension of the action space.
  • observation_dim (int, optional, defaults to 17) — Dimension of the observation space.
  • transition_dim (int, optional, defaults to 25) — Dimension of the transition space.
  • n_layer (int, optional, defaults to 4) — Number of hidden layers in the Transformer encoder.
  • n_head (int, optional, defaults to 4) — Number of attention heads for each attention layer in the Transformer encoder.
  • n_embd (int, optional, defaults to 128) — Dimensionality of the embeddings and hidden states.
  • resid_pdrop (float, optional, defaults to 0.1) — The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
  • embd_pdrop (int, optional, defaults to 0.1) — The dropout ratio for the embeddings.
  • attn_pdrop (float, optional, defaults to 0.1) — The dropout ratio for the attention.
  • hidden_act (str or function, optional, defaults to "gelu") — The non-linear activation function (function or string) in the encoder and pooler. If string, "gelu", "relu", "selu" and "gelu_new" are supported.
  • max_position_embeddings (int, optional, defaults to 512) — The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
  • type_vocab_size (int, optional, defaults to 2) — The vocabulary size of the token_type_ids passed when calling TrajectoryTransformerModel
  • initializer_range (float, optional, defaults to 0.02) — The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  • layer_norm_eps (float, optional, defaults to 1e-12) — The epsilon used by the layer normalization layers.
  • kaiming_initializer_range (`float, optional, defaults to 1) — A coefficient scaling the negative slope of the kaiming initializer rectifier for EinLinear layers.
  • use_cache (bool, optional, defaults to True) — Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True. Example —

This is the configuration class to store the configuration of a TrajectoryTransformerModel. It is used to instantiate an TrajectoryTransformer model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the TrajectoryTransformer CarlCochet/trajectory-transformer-halfcheetah-medium-v2 architecture.

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.

>>> from transformers import TrajectoryTransformerModel, TrajectoryTransformerConfig

>>> # Initializing a TrajectoryTransformer CarlCochet/trajectory-transformer-halfcheetah-medium-v2 style configuration
>>> configuration = TrajectoryTransformerConfig()

>>> # Initializing a model from the CarlCochet/trajectory-transformer-halfcheetah-medium-v2 style configuration
>>> model = TrajectoryTransformerModel(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config


class transformers.TrajectoryTransformerModel

< >

( config )


  • config (TrajectoryTransformerConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.

The bare TrajectoryTransformer Model transformer outputting raw hidden-states without any specific head on top. This model is a PyTorch torch.nn.Module sub-class. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.

the full GPT language model, with a context size of block_size


< >

( trajectories: typing.Optional[torch.LongTensor] = None past_key_values: typing.Optional[typing.Tuple[typing.Tuple[torch.Tensor]]] = None targets: typing.Optional[torch.FloatTensor] = None attention_mask: typing.Optional[torch.FloatTensor] = None use_cache: typing.Optional[bool] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None ) transformers.models.trajectory_transformer.modeling_trajectory_transformer.TrajectoryTransformerOutput or tuple(torch.FloatTensor)


  • trajectories (torch.LongTensor of shape (batch_size, sequence_length)) — Batch of trajectories, where a trajectory is a sequence of states, actions and rewards.
  • past_key_values (Tuple[Tuple[torch.Tensor]] of length config.n_layers, optional) — Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see past_key_values output below). Can be used to speed up sequential decoding. The input_ids which have their past given to this model should not be passed as input_ids as they have already been computed.
  • targets (torch.LongTensor of shape (batch_size, sequence_length), optional) — Desired targets used to compute the loss.
  • attention_mask (torch.FloatTensor of shape (batch_size, sequence_length), optional) — Mask to avoid performing attention on padding token indices. Mask values selected in [0, 1]:

    • 1 for tokens that are not masked,
    • 0 for tokens that are masked.

    What are attention masks?

  • use_cache (bool, optional) — If set to True, past_key_values key value states are returned and can be used to speed up decoding (see past_key_values).
  • output_attentions (bool, optional) — Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.
  • output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
  • return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.


transformers.models.trajectory_transformer.modeling_trajectory_transformer.TrajectoryTransformerOutput or tuple(torch.FloatTensor)

A transformers.models.trajectory_transformer.modeling_trajectory_transformer.TrajectoryTransformerOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (TrajectoryTransformerConfig) and inputs.

  • loss (torch.FloatTensor of shape (1,), optional, returned when labels is provided) — Language modeling loss.
  • logits (torch.FloatTensor of shape (batch_size, sequence_length, config.vocab_size)) — Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
  • past_key_values (Tuple[Tuple[torch.Tensor]], optional, returned when use_cache=True is passed or when config.use_cache=True) — Tuple of length config.n_layers, containing tuples of tensors of shape (batch_size, num_heads, sequence_length, embed_size_per_head)). Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see past_key_values input) to speed up sequential decoding.
  • hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of torch.FloatTensor (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size). Hidden-states of the model at the output of each layer plus the initial embedding outputs.
  • attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) — Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length). GPT2Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

The TrajectoryTransformerModel forward method, overrides the __call__ special method.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the pre and post processing steps while the latter silently ignores them.


>>> from transformers import TrajectoryTransformerModel
>>> import torch

>>> model = TrajectoryTransformerModel.from_pretrained(
...     "CarlCochet/trajectory-transformer-halfcheetah-medium-v2"
... )
>>> model.eval()

>>> observations_dim, action_dim, batch_size = 17, 6, 256
>>> seq_length = observations_dim + action_dim + 1

>>> trajectories = torch.LongTensor([np.random.permutation(self.seq_length) for _ in range(batch_size)]).to(
...     device
... )
>>> targets = torch.LongTensor([np.random.permutation(self.seq_length) for _ in range(batch_size)]).to(device)

>>> outputs = model(
...     trajectories,
...     targets=targets,
...     use_cache=True,
...     output_attentions=True,
...     output_hidden_states=True,
...     return_dict=True,
... )