# coding=utf-8
# Copyright 2022 The Trajectory Transformers paper authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch TrajectoryTransformer model."""

import math
import os
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F

from ....modeling_utils import PreTrainedModel
from ....utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_trajectory_transformer import TrajectoryTransformerConfig


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "CarlCochet/trajectory-transformer-halfcheetah-medium-v2"
_CONFIG_FOR_DOC = "TrajectoryTransformerConfig"

TRAJECTORY_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "CarlCochet/trajectory-transformer-halfcheetah-medium-v2",
    # See all TrajectoryTransformer models at https://huggingface.co/models?filter=trajectory_transformer
]


def load_tf_weights_in_trajectory_transformer(model, config, tf_checkpoint_path):
    """Load TF checkpoints into a PyTorch model."""
    try:
        import re

        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    for name, array in zip(names, arrays):
        name = name.split("/")
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
        # which are not required for using the pretrained model
        if any(
            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
            for n in name
        ):
            logger.info(f"Skipping {'/'.join(name)}")
            continue
        pointer = model
        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] == "kernel" or scope_names[0] == "gamma":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "output_weights":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "squad":
                pointer = getattr(pointer, "classifier")
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    logger.info(f"Skipping {'/'.join(name)}")
                    continue
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        if m_name[-11:] == "_embeddings":
            pointer = getattr(pointer, "weight")
        elif m_name == "kernel":
            array = np.transpose(array)
        try:
            if pointer.shape != array.shape:
                raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
        except ValueError as e:
            e.args += (pointer.shape, array.shape)
            raise
        logger.info(f"Initialize PyTorch weight {name}")
        pointer.data = torch.from_numpy(array)
    return model
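

# Usage sketch (the checkpoint path below is a placeholder, not a shipped artifact):
#   model = TrajectoryTransformerModel(TrajectoryTransformerConfig())
#   model = load_tf_weights_in_trajectory_transformer(model, model.config, "/path/to/tf_checkpoint")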


@dataclass
class TrajectoryTransformerOutput(ModelOutput):
    """
    Base class for the outputs of the model.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `targets` is provided):
            Language modeling loss.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`Tuple[Tuple[torch.Tensor]]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of length `config.n_layers`, containing tuples of tensors of shape `(batch_size, num_heads,
            sequence_length, embed_size_per_head)`. Contains pre-computed hidden-states (key and values in the
            attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attention weights after the attention softmax, used to compute the weighted average
            in the self-attention heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


class TrajectoryTransformerPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = TrajectoryTransformerConfig
    load_tf_weights = load_tf_weights_in_trajectory_transformer
    base_model_prefix = "trajectory_transformer"
    main_input_name = "trajectories"
    supports_gradient_checkpointing = True

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, TrajectoryTransformerModel):
            module.gradient_checkpointing = value

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, EinLinear):
            for i in range(module.n_models):
                nn.init.kaiming_uniform_(module.weight[i], a=math.sqrt(5) / self.config.kaiming_initializer_range)
                if module.bias is not None:
                    fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight[i])
                    bound = (1 / math.sqrt(fan_in)) * self.config.initializer_range
                    nn.init.uniform_(module.bias[i], -bound, bound)
TRAJECTORY_TRANSFORMER_START_DOCSTRING = r""" | |
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use | |
it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and | |
behavior. | |
Parameters: | |
config ([`TrajectoryTransformerConfig`]): Model configuration class with all the parameters of the model. | |
Initializing with a config file does not load the weights associated with the model, only the | |
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. | |
""" | |
TRAJECTORY_TRANSFORMER_INPUTS_DOCSTRING = r""" | |
Args: | |
trajectories (`torch.LongTensor` of shape `(batch_size, sequence_length)`): | |
Batch of trajectories, where a trajectory is a sequence of states, actions and rewards. | |
past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`, *optional*): | |
Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see | |
`past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have | |
their past given to this model should not be passed as `input_ids` as they have already been computed. | |
targets (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): | |
Desired targets used to compute the loss. | |
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): | |
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: | |
- 1 for tokens that are **not masked**, | |
- 0 for tokens that are **masked**. | |
[What are attention masks?](../glossary#attention-mask) | |
use_cache (`bool`, *optional*): | |
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see | |
`past_key_values`). | |
output_attentions (`bool`, *optional*): | |
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned | |
tensors for more detail. | |
output_hidden_states (`bool`, *optional*): | |
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for | |
more detail. | |
return_dict (`bool`, *optional*): | |
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. | |
""" | |


class EinLinear(nn.Module):
    def __init__(self, n_models, in_features, out_features, bias):
        super().__init__()
        self.n_models = n_models
        self.out_features = out_features
        self.in_features = in_features
        self.weight = nn.Parameter(torch.Tensor(n_models, out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(n_models, out_features))
        else:
            self.register_parameter("bias", None)

    def reset_parameters(self):
        for i in range(self.n_models):
            nn.init.kaiming_uniform_(self.weight[i], a=math.sqrt(5))
            if self.bias is not None:
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[i])
                bound = 1 / math.sqrt(fan_in)
                nn.init.uniform_(self.bias[i], -bound, bound)

    def forward(self, input):
        """
        Args:
            input (`torch.FloatTensor` of shape `(B, n_models, input_dim)`):
                The input to the layer.
        """
        # [ batch_size x n_models x output_dim ]
        output = torch.einsum("eoi,bei->beo", self.weight, input)
        if self.bias is not None:
            raise RuntimeError("EinLinear does not support a bias term in its forward pass.")
        return output


class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()

        if config.n_embd % config.n_head != 0:
            raise ValueError(f"n_head ({config.n_head}) should be a divisor of n_embd ({config.n_embd})")

        # key, query, value projections for all heads
        self.key = nn.Linear(config.n_embd, config.n_embd)
        self.query = nn.Linear(config.n_embd, config.n_embd)
        self.value = nn.Linear(config.n_embd, config.n_embd)

        # regularization
        self.attn_drop = nn.Dropout(config.attn_pdrop)
        self.resid_drop = nn.Dropout(config.resid_pdrop)

        # output projection
        self.proj = nn.Linear(config.n_embd, config.n_embd)

        # causal mask to ensure that attention is only applied to the left in the input sequence
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(config.block_size, config.block_size)).view(
                1, 1, config.block_size, config.block_size
            ),
            persistent=False,
        )

        # mask previous value estimates
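        # (every `joined_dim`-th column of the mask, i.e. the value-token positions, is zeroed out,
        # so no position is allowed to attend to an earlier value estimate)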
        joined_dim = config.observation_dim + config.action_dim + 2
        self.mask.squeeze()[:, joined_dim - 1 :: joined_dim] = 0

        self.n_head = config.n_head

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ):
        batch_size, sequence_length, embedding_dim = hidden_states.size()

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        # [ batch_size x n_heads x sequence_length x head_dim ]
        key = (
            self.key(hidden_states)
            .view(batch_size, sequence_length, self.n_head, embedding_dim // self.n_head)
            .transpose(1, 2)
        )
        query = (
            self.query(hidden_states)
            .view(batch_size, sequence_length, self.n_head, embedding_dim // self.n_head)
            .transpose(1, 2)
        )
        value = (
            self.value(hidden_states)
            .view(batch_size, sequence_length, self.n_head, embedding_dim // self.n_head)
            .transpose(1, 2)
        )

        if layer_past is not None:
            past_key, past_value = layer_past
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        if use_cache is True:
            present = (key, value)
        else:
            present = None

        # causal self-attention
        # [ batch_size x n_heads x sequence_length x sequence_length ]
        attn_weights = (torch.matmul(query, key.transpose(-2, -1))) * (1.0 / math.sqrt(key.size(-1)))
        attn_weights = attn_weights.masked_fill(
            self.mask[:, :, :sequence_length, :sequence_length] == 0, torch.finfo(attn_weights.dtype).min
        )
        attn_weights = F.softmax(attn_weights, dim=-1)
        self._attn_map = attn_weights.clone()
        attn_weights = self.attn_drop(attn_weights)

        output = torch.matmul(attn_weights, value)
        # [ batch_size x sequence_length x embedding_dim ]
        # re-assemble all head outputs side by side
        output = output.transpose(1, 2).contiguous().view(batch_size, sequence_length, embedding_dim)

        # output projection
        output = self.resid_drop(self.proj(output))

        outputs = (output, present)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs


class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)

        # MLP
        self.l1 = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.act = nn.GELU()
        self.l2 = nn.Linear(4 * config.n_embd, config.n_embd)
        self.drop = nn.Dropout(config.resid_pdrop)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ):
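        # Pre-LayerNorm residual block: hidden = hidden + Attn(LN1(hidden)); hidden = hidden + MLP(LN2(hidden))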
        residual = hidden_states
        hidden_states = self.ln1(hidden_states)

        attn_outputs = self.attn(
            hidden_states, layer_past=layer_past, use_cache=use_cache, output_attentions=output_attentions
        )
        attn_output = attn_outputs[0]
        outputs = attn_outputs[1:]
        hidden_states = attn_output + residual

        residual = hidden_states
        hidden_states = self.ln2(hidden_states)
        hidden_states = self.l1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.l2(hidden_states)
        hidden_states = residual + self.drop(hidden_states)

        if use_cache:
            outputs = (hidden_states,) + outputs
        else:
            outputs = (hidden_states,) + outputs[1:]

        return outputs


@add_start_docstrings(
    "The bare TrajectoryTransformer Model transformer outputting raw hidden-states without any specific head on top.",
    TRAJECTORY_TRANSFORMER_START_DOCSTRING,
)
class TrajectoryTransformerModel(TrajectoryTransformerPreTrainedModel):
    """The full GPT language model, with a context size of `block_size`."""

    def __init__(self, config):
        super().__init__(config)

        # input embedding stem (+1 for stop token)
        self.tok_emb = nn.Embedding(config.vocab_size * config.transition_dim + 1, config.n_embd)

        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
        self.drop = nn.Dropout(config.embd_pdrop)
        # transformer
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        # decoder head
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.head = EinLinear(config.transition_dim, config.n_embd, config.vocab_size + 1, bias=False)

        self.vocab_size = config.vocab_size
        self.stop_token = config.vocab_size * config.transition_dim
        self.block_size = config.block_size

        self.observation_dim = config.observation_dim
        self.action_dim = config.action_dim
        self.transition_dim = config.transition_dim
        self.embedding_dim = config.n_embd

        self.action_weight = config.action_weight
        self.reward_weight = config.reward_weight
        self.value_weight = config.value_weight

        self.gradient_checkpointing = False

        self.post_init()

    def get_block_size(self):
        return self.block_size
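
    # Each of the `transition_dim` positions within a transition owns a disjoint `vocab_size`-sized
    # slice of the shared embedding table (plus one extra id for the stop token), so raw
    # per-dimension tokens in `[0, vocab_size)` are shifted by `position * vocab_size` below.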
    def offset_tokens(self, trajectories):
        _, sequence_length = trajectories.shape

        n_states = int(np.ceil(sequence_length / self.transition_dim))

        offsets = torch.arange(self.transition_dim) * self.vocab_size
        offsets = offsets.repeat(n_states).to(trajectories.device)

        offset_trajectories = trajectories + offsets[:sequence_length]
        offset_trajectories[trajectories == self.vocab_size] = self.stop_token
        return offset_trajectories
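
    # Pad the sequence to a multiple of `transition_dim` and fold it to
    # `(batch_size * n_states, transition_dim, embedding_dim)` so the `EinLinear` head can apply a
    # separate projection to each position within a transition.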
    def pad_to_full_observation(self, hidden_states):
        batch_size, sequence_length, _ = hidden_states.shape

        n_pad = (self.transition_dim - sequence_length % self.transition_dim) % self.transition_dim
        padding = torch.zeros(batch_size, n_pad, self.embedding_dim, device=hidden_states.device)

        # [ batch_size x padded_sequence_length' x embedding_dim ]
        hidden_states_pad = torch.cat([hidden_states, padding], dim=1)
        hidden_states_pad = hidden_states_pad.view(-1, self.transition_dim, self.embedding_dim)

        return hidden_states_pad, n_pad

    @add_start_docstrings_to_model_forward(TRAJECTORY_TRANSFORMER_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TrajectoryTransformerOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        trajectories: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        targets: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], TrajectoryTransformerOutput]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from transformers import TrajectoryTransformerModel
        >>> import numpy as np
        >>> import torch

        >>> device = torch.device("cpu")

        >>> model = TrajectoryTransformerModel.from_pretrained(
        ...     "CarlCochet/trajectory-transformer-halfcheetah-medium-v2"
        ... )
        >>> model.to(device)
        >>> model.eval()

        >>> observations_dim, action_dim, batch_size = 17, 6, 256
        >>> seq_length = observations_dim + action_dim + 1

        >>> trajectories = torch.LongTensor([np.random.permutation(seq_length) for _ in range(batch_size)]).to(
        ...     device
        ... )
        >>> targets = torch.LongTensor([np.random.permutation(seq_length) for _ in range(batch_size)]).to(device)
        >>> attention_mask = torch.ones(batch_size, seq_length, device=device)

        >>> outputs = model(
        ...     trajectories,
        ...     targets=targets,
        ...     attention_mask=attention_mask,
        ...     use_cache=True,
        ...     output_attentions=True,
        ...     output_hidden_states=True,
        ...     return_dict=True,
        ... )
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if past_key_values is None:
            past_key_values = tuple([None] * len(self.blocks))

        batch_size, sequence_length = trajectories.size()

        if sequence_length > self.block_size:
            raise ValueError("Cannot forward, model block size is exhausted.")

        offset_trajectories = self.offset_tokens(trajectories)
        # [ batch_size x sequence_length x embedding_dim ]
        # forward the GPT model
        token_embeddings = self.tok_emb(offset_trajectories)  # each index maps to a (learnable) vector
        position_embeddings = self.pos_emb[:, :sequence_length, :]  # each position maps to a (learnable) vector

        hidden_states = self.drop(token_embeddings + position_embeddings)

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        for i, (block, layer_past) in enumerate(zip(self.blocks, past_key_values)):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    layer_past,
                    use_cache,
                    output_attentions,
                )
            else:
                outputs = block(hidden_states, layer_past, use_cache, output_attentions)

            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)

        # [ batch_size x sequence_length x embedding_dim ]
        hidden_state = self.ln_f(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        hidden_states_pad, n_pad = self.pad_to_full_observation(hidden_state)
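        # the EinLinear head applies a separate projection to each of the `transition_dim` positions;
        # fold the result back to (batch_size, padded_length, vocab_size + 1) and crop the padding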
        logits = self.head(hidden_states_pad)
        logits = logits.reshape(batch_size, sequence_length + n_pad, self.vocab_size + 1)
        logits = logits[:, :sequence_length]

        # if we are given some desired targets also calculate the loss
        if targets is not None:
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.view(-1), reduction="none")
            if self.action_weight != 1 or self.reward_weight != 1 or self.value_weight != 1:
                # make weights
                n_states = int(np.ceil(sequence_length / self.transition_dim))
                weights = torch.cat(
                    [
                        torch.ones(self.observation_dim, device=trajectories.device),
                        torch.ones(self.action_dim, device=trajectories.device) * self.action_weight,
                        torch.ones(1, device=trajectories.device) * self.reward_weight,
                        torch.ones(1, device=trajectories.device) * self.value_weight,
                    ]
                )
                weights = weights.repeat(n_states)
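                # drop the weight for the first position: `targets` are expected to be the input
                # trajectory shifted by one token, so the per-token weights are shifted accordingly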
                weights = weights[1:].repeat(batch_size, 1)
                loss = loss * weights.view(-1)
            loss = (loss * attention_mask.view(-1)).mean()
        else:
            loss = None

        if not return_dict:
            return tuple(v for v in [loss, logits, presents, all_hidden_states, all_self_attentions] if v is not None)

        return TrajectoryTransformerOutput(
            loss=loss,
            logits=logits,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )