The Trajectory Transformer model was proposed in "Offline Reinforcement Learning as One Big Sequence Modeling Problem" by Michael Janner, Qiyang Li, and Sergey Levine.
The abstract from the paper is the following:
Reinforcement learning (RL) is typically concerned with estimating stationary policies or single-step models, leveraging the Markov property to factorize problems in time. However, we can also view RL as a generic sequence modeling problem, with the goal being to produce a sequence of actions that leads to a sequence of high rewards. Viewed in this way, it is tempting to consider whether high-capacity sequence prediction models that work well in other domains, such as natural-language processing, can also provide effective solutions to the RL problem. To this end, we explore how RL can be tackled with the tools of sequence modeling, using a Transformer architecture to model distributions over trajectories and repurposing beam search as a planning algorithm. Framing RL as sequence modeling problem simplifies a range of design decisions, allowing us to dispense with many of the components common in offline RL algorithms. We demonstrate the flexibility of this approach across long-horizon dynamics prediction, imitation learning, goal-conditioned RL, and offline RL. Further, we show that this approach can be combined with existing model-free algorithms to yield a state-of-the-art planner in sparse-reward, long-horizon tasks.
This Transformer is used for deep reinforcement learning. To use it, build a sequence from the states, actions, and rewards of all previous timesteps. The model treats all of these elements together as one big sequence (a trajectory).
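The sketch below illustrates the "one big sequence" idea. It discretizes every continuous dimension of each transition into one of vocab_size bins and then concatenates the resulting tokens. It rests on several assumptions: uniform bins over a known [low, high] range, and an (observation, action, reward, value) order within each transition. The helper names discretize and flatten_trajectory are hypothetical, and the original codebase may use a different discretization scheme.

```python
from typing import List, Sequence


def discretize(value: float, low: float, high: float, vocab_size: int = 100) -> int:
    """Map a continuous value to one of `vocab_size` uniform bins (assumed scheme)."""
    # Clip into the assumed range so out-of-range values fall into the edge bins.
    clipped = min(max(value, low), high)
    # Scale to [0, vocab_size) and truncate; the top edge is folded into the last bin.
    index = int((clipped - low) / (high - low) * vocab_size)
    return min(index, vocab_size - 1)


def flatten_trajectory(
    transitions: Sequence[Sequence[float]],
    low: float,
    high: float,
    vocab_size: int = 100,
) -> List[int]:
    """Turn [(obs..., action..., reward, value), ...] into one flat token list."""
    tokens: List[int] = []
    for transition in transitions:
        # Each transition contributes one token per dimension, in a fixed order.
        tokens.extend(discretize(x, low, high, vocab_size) for x in transition)
    return tokens


# Two toy transitions with 2 observation dims, 1 action dim, a reward and a value
# (so 5 tokens per transition here, instead of the default transition_dim of 25).
toy = [
    [0.0, 0.5, -1.0, 1.0, 0.25],
    [-0.5, 0.75, -0.25, 2.0, -3.0],
]
tokens = flatten_trajectory(toy, low=-1.0, high=1.0, vocab_size=100)
print(tokens)  # [50, 75, 0, 99, 62, 25, 87, 37, 99, 0]
```

To build a batch for TrajectoryTransformerModel, you could then wrap the tokens with torch.tensor([tokens], dtype=torch.long). This gives the (batch_size, sequence_length) shape that the forward method expects.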
class transformers.TrajectoryTransformerConfig
( vocab_size = 100 action_weight = 5 reward_weight = 1 value_weight = 1 block_size = 249 action_dim = 6 observation_dim = 17 transition_dim = 25 n_layer = 4 n_head = 4 n_embd = 128 embd_pdrop = 0.1 attn_pdrop = 0.1 resid_pdrop = 0.1 learning_rate = 0.0006 max_position_embeddings = 512 type_vocab_size = 2 initializer_range = 0.02 layer_norm_eps = 1e-12 kaiming_initializer_range = 1 use_cache = True pad_token_id = 1 bos_token_id = 50256 eos_token_id = 50256 **kwargs )
- vocab_size (int, optional, defaults to 100) — Vocabulary size of the TrajectoryTransformer model. Defines the number of different tokens that can be represented by the trajectories passed when calling TrajectoryTransformerModel.
- action_weight (int, optional, defaults to 5) — Weight of the action in the loss function.
- reward_weight (int, optional, defaults to 1) — Weight of the reward in the loss function.
- value_weight (int, optional, defaults to 1) — Weight of the value in the loss function.
- block_size (int, optional, defaults to 249) — Context size of the model, i.e. the maximum number of tokens it attends over.
- action_dim (int, optional, defaults to 6) — Dimension of the action space.
- observation_dim (int, optional, defaults to 17) — Dimension of the observation space.
- transition_dim (int, optional, defaults to 25) — Dimension of the transition space.
- n_layer (int, optional, defaults to 4) — Number of hidden layers in the Transformer.
- n_head (int, optional, defaults to 4) — Number of attention heads for each attention layer in the Transformer.
- n_embd (int, optional, defaults to 128) — Dimensionality of the embeddings and hidden states.
- resid_pdrop (float, optional, defaults to 0.1) — The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
- embd_pdrop (float, optional, defaults to 0.1) — The dropout ratio for the embeddings.
- attn_pdrop (float, optional, defaults to 0.1) — The dropout ratio for the attention.
- hidden_act (str or function, optional, defaults to "gelu") — The non-linear activation function in the encoder and pooler, given either as a function or as a string naming it.
- max_position_embeddings (int, optional, defaults to 512) — The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 512, 1024, or 2048).
- type_vocab_size (int, optional, defaults to 2) — The vocabulary size of the token_type_ids passed when calling TrajectoryTransformerModel.
- initializer_range (float, optional, defaults to 0.02) — The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
- layer_norm_eps (float, optional, defaults to 1e-12) — The epsilon used by the layer normalization layers.
- kaiming_initializer_range (float, optional, defaults to 1) — A coefficient scaling the negative slope of the Kaiming initializer rectifier for EinLinear layers.
- use_cache (bool, optional, defaults to True) — Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if config.is_decoder=True.
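The default dimensions fit together arithmetically, and the check below makes that explicit. It assumes each transition holds the observation, the action, one reward token and one value token. It also assumes block_size covers 10 transitions minus one token, the token lost to the one-token shift between inputs and targets. The figure of 10 transitions is an assumption chosen because it reproduces the default; it is not stated on this page.

```python
# Default values from TrajectoryTransformerConfig.
observation_dim = 17
action_dim = 6
block_size = 249

# Assumption: one reward token and one value token per transition.
transition_dim = observation_dim + action_dim + 2

# Assumption: 10 transitions per training sequence. Inputs and targets are the
# sequence shifted by one token, so the context holds one token fewer.
n_transitions = 10
derived_block_size = n_transitions * transition_dim - 1

print(transition_dim, derived_block_size)  # 25 249
```

If you change observation_dim or action_dim for another environment, transition_dim (and likely block_size) presumably needs to be updated to match.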
This is the configuration class to store the configuration of a TrajectoryTransformerModel. It is used to instantiate a TrajectoryTransformer model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the TrajectoryTransformer CarlCochet/trajectory-transformer-halfcheetah-medium-v2 architecture.

Example:
```python
from transformers import TrajectoryTransformerConfig, TrajectoryTransformerModel

# Initializing a TrajectoryTransformer CarlCochet/trajectory-transformer-halfcheetah-medium-v2 style configuration
configuration = TrajectoryTransformerConfig()

# Initializing a model (with random weights) from the CarlCochet/trajectory-transformer-halfcheetah-medium-v2 style configuration
model = TrajectoryTransformerModel(configuration)

# Accessing the model configuration
configuration = model.config
```
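The action_weight, reward_weight and value_weight parameters rescale the per-token loss. The sketch below shows one way such a per-dimension weight vector could be built and applied. It assumes observation tokens keep weight 1 and that each transition is laid out as observation, action, reward, value. The helpers transition_loss_weights and weighted_mean are hypothetical illustrations, not the library's loss code.

```python
from typing import List


def transition_loss_weights(
    observation_dim: int = 17,
    action_dim: int = 6,
    action_weight: float = 5,
    reward_weight: float = 1,
    value_weight: float = 1,
) -> List[float]:
    """Hypothetical helper: one loss weight per token of a single transition."""
    return (
        [1.0] * observation_dim  # observation tokens (assumed weight 1)
        + [float(action_weight)] * action_dim  # action tokens
        + [float(reward_weight)]  # the single reward token
        + [float(value_weight)]  # the single value token
    )


def weighted_mean(losses: List[float], weights: List[float]) -> float:
    """Mean of per-token losses after tiling the per-transition weights."""
    repeats = len(losses) // len(weights) + 1
    tiled = (weights * repeats)[: len(losses)]
    return sum(loss * weight for loss, weight in zip(losses, tiled)) / len(losses)


weights = transition_loss_weights()  # defaults mirror TrajectoryTransformerConfig
print(len(weights), sum(weights))  # 25 49.0
```

With the defaults, each action token contributes five times as much to the average as an observation token.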
class transformers.TrajectoryTransformerModel
( config )
The TrajectoryTransformer Model transformer: the full GPT language model, with a context size of block_size, whose forward pass returns prediction scores from a language modeling head. This model is a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.
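Every dimension is discretized into the same vocab_size bins, so a given bin index means different things in different dimensions. One way to keep them apart, assumed here for illustration, is to shift the token at dimension d by d * vocab_size before the embedding lookup. An embedding table with transition_dim * vocab_size rows (plus any special tokens) would then cover every id. The helper offset_tokens is hypothetical. It also assumes the sequence starts at the first dimension of a transition.

```python
from typing import List


def offset_tokens(tokens: List[int], transition_dim: int, vocab_size: int) -> List[int]:
    """Give every dimension its own block of `vocab_size` embedding ids (assumed scheme)."""
    # Position i in the flat sequence belongs to dimension i % transition_dim,
    # assuming the sequence starts at the first dimension of a transition.
    return [tok + (i % transition_dim) * vocab_size for i, tok in enumerate(tokens)]


# Toy setting: 5 tokens per transition, vocab_size = 100, two transitions.
flat = [50, 75, 0, 99, 62, 25, 87, 37, 99, 0]
print(offset_tokens(flat, transition_dim=5, vocab_size=100))
# [50, 175, 200, 399, 462, 25, 187, 237, 399, 400]
```

Note how the two occurrences of bin 99 (positions 3 and 8) both land on id 399, because they belong to the same dimension. Bin 0 maps to different ids (200 and 400) because it appears in different dimensions.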
forward
( trajectories: typing.Optional[torch.LongTensor] = None, past_key_values: typing.Optional[typing.Tuple[typing.Tuple[torch.Tensor]]] = None, targets: typing.Optional[torch.FloatTensor] = None, attention_mask: typing.Optional[torch.FloatTensor] = None, use_cache: typing.Optional[bool] = None, output_attentions: typing.Optional[bool] = None, output_hidden_states: typing.Optional[bool] = None, return_dict: typing.Optional[bool] = None )
- trajectories (torch.LongTensor of shape (batch_size, sequence_length)) — Batch of trajectories, where a trajectory is a sequence of states, actions and rewards.
- past_key_values (Tuple[Tuple[torch.Tensor]] of length config.n_layer, optional) — Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see the past_key_values output below). Can be used to speed up sequential decoding. Trajectory tokens whose past is given to this model should not be passed again, as they have already been computed.
- targets (tensor of shape (batch_size, sequence_length), optional) — Desired target tokens used to compute the loss. The example below passes them as a torch.LongTensor.
- attention_mask (torch.FloatTensor of shape (batch_size, sequence_length), optional) — Mask to avoid performing attention on padding token indices. Mask values selected in [0, 1]:
  - 1 for tokens that are not masked,
  - 0 for tokens that are masked.
- use_cache (bool, optional) — If set to True, past_key_values key value states are returned and can be used to speed up decoding (see past_key_values).
- output_attentions (bool, optional) — Whether or not to return the attention tensors of all attention layers. See attentions under returned tensors for more detail.
- output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
- return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
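When trajectories in a batch have different lengths, they must be padded to a common length. The attention_mask then marks which positions are real. The sketch below right-pads token lists and builds the matching 1/0 mask. Using the configuration's default pad_token_id of 1 as the padding value is an assumption; because the mask hides padded positions, the exact value mostly matters for readability. The helper pad_batch is hypothetical.

```python
from typing import List, Tuple


def pad_batch(
    sequences: List[List[int]], pad_token_id: int = 1
) -> Tuple[List[List[int]], List[List[int]]]:
    """Right-pad token sequences and build the matching 1/0 attention mask."""
    max_len = max(len(seq) for seq in sequences)
    padded: List[List[int]] = []
    mask: List[List[int]] = []
    for seq in sequences:
        n_pad = max_len - len(seq)
        padded.append(seq + [pad_token_id] * n_pad)
        # 1 marks real tokens (attended to), 0 marks padding (masked out).
        mask.append([1] * len(seq) + [0] * n_pad)
    return padded, mask


batch, attention_mask = pad_batch([[50, 75, 0], [25, 87]])
print(batch)           # [[50, 75, 0], [25, 87, 1]]
print(attention_mask)  # [[1, 1, 1], [1, 1, 0]]
```

Because the documented type is a float tensor, the mask would then be converted with torch.tensor(attention_mask, dtype=torch.float).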
Returns a transformers.models.trajectory_transformer.modeling_trajectory_transformer.TrajectoryTransformerOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (TrajectoryTransformerConfig) and inputs:
- loss (torch.FloatTensor of shape (1,), optional, returned when targets is provided) — Language modeling loss.
- logits (torch.FloatTensor of shape (batch_size, sequence_length, config.vocab_size)) — Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
- past_key_values (Tuple[Tuple[torch.Tensor]], optional, returned when use_cache=True is passed or when config.use_cache=True) — Tuple of length config.n_layer, containing tuples of tensors of shape (batch_size, num_heads, sequence_length, embed_size_per_head). Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see the past_key_values input) to speed up sequential decoding.
- hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of torch.FloatTensor (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size). Hidden-states of the model at the output of each layer plus the initial embedding outputs.
- attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) — Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length). Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
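Each position's logits score the possible bins for the following token. The sketch below greedily picks the highest-scoring bin for one position and maps it back to a continuous value, assuming uniform bins over a known [low, high] range. Greedy argmax is a simplification chosen for brevity; the paper plans with beam search instead. The helpers argmax and bin_center are hypothetical.

```python
from typing import List


def argmax(scores: List[float]) -> int:
    """Index of the largest score (the first one on ties)."""
    return max(range(len(scores)), key=scores.__getitem__)


def bin_center(index: int, low: float, high: float, vocab_size: int) -> float:
    """Continuous value at the centre of uniform bin `index` (assumed binning)."""
    width = (high - low) / vocab_size
    return low + (index + 0.5) * width


# Toy logits for one position with vocab_size = 4 (the configured default is 100).
logits = [0.1, 2.3, -0.7, 1.9]
token = argmax(logits)
print(token, bin_center(token, low=-1.0, high=1.0, vocab_size=4))  # 1 -0.25
```

Decoding to the bin centre rather than a bin edge keeps the reconstruction error within half a bin width under this assumed binning.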
The TrajectoryTransformerModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
Example:

```python
import torch
from transformers import TrajectoryTransformerModel

device = "cuda" if torch.cuda.is_available() else "cpu"

model = TrajectoryTransformerModel.from_pretrained(
    "CarlCochet/trajectory-transformer-halfcheetah-medium-v2"
)
model.to(device)
model.eval()

observations_dim, action_dim, batch_size = 17, 6, 256
seq_length = observations_dim + action_dim + 1

# Random token sequences, only to illustrate the call signature
trajectories = torch.stack([torch.randperm(seq_length) for _ in range(batch_size)]).to(device)
targets = torch.stack([torch.randperm(seq_length) for _ in range(batch_size)]).to(device)
# All-ones mask: this toy batch contains no padding
attention_mask = torch.ones(batch_size, seq_length, device=device)

with torch.no_grad():
    outputs = model(
        trajectories,
        targets=targets,
        attention_mask=attention_mask,
        use_cache=True,
        output_attentions=True,
        output_hidden_states=True,
        return_dict=True,
    )
```
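Random permutations are enough to exercise the API. For actual training, this sketch assumes standard next-token prediction, as the GPT-style design suggests: the inputs are a token sequence without its last token, and the targets are the same sequence without its first token. The helper make_inputs_and_targets is hypothetical.

```python
from typing import List, Tuple


def make_inputs_and_targets(tokens: List[int]) -> Tuple[List[int], List[int]]:
    """Split one token sequence into next-token prediction inputs and targets."""
    # The model sees tokens[:-1] and is trained to predict tokens[1:],
    # so position i of the targets is the token that follows input position i.
    return tokens[:-1], tokens[1:]


sequence = [50, 75, 0, 99, 62, 25, 87, 37, 99, 0]
inputs, targets = make_inputs_and_targets(sequence)
print(inputs)   # [50, 75, 0, 99, 62, 25, 87, 37, 99]
print(targets)  # [75, 0, 99, 62, 25, 87, 37, 99, 0]
```

Wrapped as torch.tensor([inputs]) and torch.tensor([targets]), these could replace the random permutations above. The one-token shift is also consistent with a default block_size of 249, i.e. 10 transitions of 25 tokens minus one.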