|
import warnings |
|
from dataclasses import dataclass |
|
from typing import List, Optional, Tuple, Union |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn.functional as F |
|
from gymnasium import spaces |
|
from torch import BoolTensor, FloatTensor, LongTensor, Tensor, nn |
|
from transformers import GPTNeoModel, GPTNeoPreTrainedModel |
|
from transformers.modeling_outputs import ModelOutput |
|
from transformers.models.vit.modeling_vit import ViTPatchEmbeddings |
|
|
|
from .configuration_jat import JatConfig |
|
from .processing_jat import JatProcessor |
|
|
|
|
|
def compute_mse_loss(
    predicted: FloatTensor, true: FloatTensor, mask: Optional[BoolTensor], weights: Optional[FloatTensor] = None
) -> FloatTensor:
    """
    Mean Squared Error between predictions and targets, averaged over the valid timesteps only.

    Every feature dimension (dims >= 2) is averaged first, so each timestep contributes a single value. Invalid
    timesteps are then zeroed out by `mask`, optional per-timestep `weights` are applied, and the result is
    normalized by the number of valid timesteps (or by the total number of timesteps when no mask is given).

    Args:
        predicted (`FloatTensor` of shape `(batch_size, max_seq_len, ...)`):
            Model predictions.
        true (`FloatTensor` of shape `(batch_size, max_seq_len, ...)`):
            Targets.
        mask (`BoolTensor` of shape `(batch_size, max_seq_len)`, *optional*):
            `True` for timesteps that count towards the loss.
        weights (`FloatTensor` of shape `(batch_size, max_seq_len)`, *optional*):
            Per-timestep multiplicative weights.

    Returns:
        loss (`FloatTensor` of shape `(,)`):
            The scalar MSE loss.
    """
    squared_error = F.mse_loss(predicted, true, reduction="none")

    # Collapse the feature dimensions one at a time, innermost first, until (batch_size, max_seq_len) remains.
    while squared_error.dim() > 2:
        squared_error = squared_error.mean(dim=-1)

    if mask is not None:
        squared_error = squared_error * mask
    if weights is not None:
        squared_error = squared_error * weights

    if mask is None:
        return squared_error.mean()
    # Masked-out timesteps are zero, so divide by the count of valid ones rather than by the full size.
    return squared_error.sum() / mask.sum()
|
|
|
|
|
def compute_ce_loss(
    logits: FloatTensor, labels: torch.LongTensor, mask: Optional[BoolTensor], weights: Optional[FloatTensor] = None
) -> FloatTensor:
    """
    Compute the Cross Entropy (CE) loss between predicted logits and true class labels, considering valid timesteps.

    The loss is first reduced to one value per timestep (averaging over the optional inner dimension), then
    multiplied by the optional per-timestep weights, and finally averaged over the kept timesteps.

    Args:
        logits (`FloatTensor` of shape `(batch_size, max_seq_len, [inner_size,] num_classes)`):
            Predicted logits at the output of the model.
        labels (`torch.LongTensor` of shape `(batch_size, max_seq_len, [inner_size,])`):
            Ground truth class labels.
        mask (`BoolTensor` of shape `(batch_size, max_seq_len)`, *optional*):
            Boolean mask indicating valid timesteps. Only valid timesteps contribute to the loss.
        weights (`FloatTensor` of shape `(batch_size, max_seq_len)`, *optional*):
            Per-timestep weights applied to the loss.

    Returns:
        loss (`FloatTensor` of shape `(,)`):
            CE loss between predicted logits and true class labels.
    """
    if mask is not None:
        # Keep valid timesteps only: (batch_size, max_seq_len, ...) -> (num_valid, ...).
        mask = mask.bool()
        logits = logits[mask]
        labels = labels[mask]
        if weights is not None:
            weights = weights[mask]
    else:
        # Merge only the batch and time dims so the class dim (and any inner dim) is preserved:
        # (batch_size, max_seq_len, ...) -> (batch_size * max_seq_len, ...).
        logits = logits.flatten(end_dim=1)
        labels = labels.flatten(end_dim=1)
        if weights is not None:
            weights = weights.flatten(end_dim=1)

    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), labels.reshape(-1), reduction="none")
    loss = loss.view(labels.size())  # (num_timesteps,) or (num_timesteps, inner_size)

    # Average over the inner dim (when present) to get exactly one value per timestep. Without an inner dim the
    # loss is already per-timestep and must not be collapsed before the weights are applied.
    if loss.dim() > 1:
        loss = loss.flatten(start_dim=1).mean(dim=-1)

    if weights is not None:
        loss = loss * weights

    return loss.mean()
|
|
|
|
|
def cyclic_expand_dim(tensor: Tensor, expanded_dim_size: int) -> Tensor:
    """
    Expands the last dimension of a tensor cyclically to a specified size.

    Works for tensors of any rank >= 1; only the last dimension is expanded.

    Args:
        tensor (`torch.Tensor` of shape `(..., original_size)`):
            Input tensor whose last dimension is to be expanded cyclically, typically `(batch_size, seq_len, X)`.
        expanded_dim_size (`int`):
            The desired size of the last dimension after expansion. Must be >= the original size.

    Returns:
        `torch.Tensor` of shape `(..., expanded_dim_size)`:
            A tensor with its last dimension expanded cyclically to the specified size.

    Raises:
        ValueError: If `expanded_dim_size` is smaller than the original last dimension, or if the last dimension
            is empty while a non-empty expansion is requested.

    Examples:
        >>> tensor = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
        >>> cyclic_expand_dim(tensor, 5)
        tensor([[[1, 2, 1, 2, 1], [3, 4, 3, 4, 3]], [[5, 6, 5, 6, 5], [7, 8, 7, 8, 7]]])
    """
    original_size = tensor.shape[-1]
    if expanded_dim_size < original_size:
        raise ValueError(
            f"Expanded dimension size ({expanded_dim_size}) must be greater than or equal to the original "
            f"dimension size ({original_size})."
        )
    if original_size == 0 and expanded_dim_size > 0:
        # There is nothing to cycle over; avoid an opaque modulo-by-zero error.
        raise ValueError("Cannot cyclically expand a tensor whose last dimension is empty.")
    # Build the indices on the input's device so indexing does not mix devices.
    indices = torch.arange(expanded_dim_size, device=tensor.device) % original_size
    return tensor[..., indices]
|
|
|
|
|
class ResidualBlock(nn.Module):
    """
    Residual block: two 3x3 convolutions (each followed by layer normalization) plus a 1x1-convolution shortcut.

    The spatial size is preserved (stride 1, padding 1); only the channel count changes to `out_channels`.

    Args:
        in_shape (`Tuple[int, int, int]`):
            Shape of the input tensor as `(channels, height, width)`.
        out_channels (`int`):
            Number of output channels.

    Returns:
        `torch.Tensor` of shape `(batch_size, out_channels, in_shape[1], in_shape[2])`:
            Output tensor.
    """

    def __init__(self, in_shape: Tuple[int, int, int], out_channels: int) -> None:
        super().__init__()
        in_channels, height, width = in_shape[0], in_shape[1], in_shape[2]
        normalized_shape = (out_channels, height, width)

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm1 = nn.LayerNorm(normalized_shape)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.LayerNorm(normalized_shape)

        # Projection so the shortcut matches the main path's channel count before the addition.
        shortcut_layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1),
            nn.LayerNorm(normalized_shape),
        ]
        self.shortcut = nn.Sequential(*shortcut_layers)

    def forward(self, x: FloatTensor) -> FloatTensor:
        hidden = F.leaky_relu(self.norm1(self.conv1(x)))
        main_path = self.norm2(self.conv2(hidden))
        combined = main_path + self.shortcut(x)
        return F.leaky_relu(combined, inplace=True)
|
|
|
|
|
class AttentionLayer(nn.Module):
    """
    Channel attention in the squeeze-and-excitation style.

    Each channel is globally average-pooled. The pooled vector goes through a bottleneck MLP (reduction factor 8)
    ending in a sigmoid, and the resulting per-channel gates rescale the input.

    Args:
        num_channels (`int`):
            Number of channels of the input.

    Returns:
        `torch.Tensor`:
            Tensor of the same shape as the input, `(batch_size, num_channels, height, width)`.
    """

    def __init__(self, num_channels: int) -> None:
        super().__init__()
        bottleneck = num_channels // 8
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(num_channels, bottleneck, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck, num_channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x: FloatTensor) -> FloatTensor:
        batch, channels, _, _ = x.size()
        squeezed = self.avg_pool(x).view(batch, channels)
        gates = self.fc(squeezed).view(batch, channels, 1, 1)
        # Broadcast one gate per channel over the spatial dimensions.
        return x * gates.expand_as(x)
|
|
|
|
|
class ImageEncoder(nn.Module):
    """
    Encode a batch of 4-channel images into hidden vectors.

    Three stride-2 convolution stages (conv -> instance norm -> leaky ReLU -> channel attention) reduce the
    spatial size. A linear layer then maps the flattened 128x11x11 feature map to `hidden_size`, so the input's
    spatial size must reduce to 11x11 after three stride-2 stages (e.g. 84x84 -> 42 -> 21 -> 11).

    Args:
        hidden_size (`int`):
            Size of the output hidden state.

    Returns:
        `torch.Tensor` of shape `(batch_size, hidden_size)`:
            Output tensor.
    """

    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        downsample = dict(kernel_size=3, stride=2, padding=1)
        self.conv1 = nn.Conv2d(4, 32, **downsample)
        self.norm1 = nn.InstanceNorm2d(32)
        self.att1 = AttentionLayer(32)
        self.conv2 = nn.Conv2d(32, 64, **downsample)
        self.norm2 = nn.InstanceNorm2d(64)
        self.att2 = AttentionLayer(64)
        self.conv3 = nn.Conv2d(64, 128, **downsample)
        self.norm3 = nn.InstanceNorm2d(128)
        self.att3 = AttentionLayer(128)
        self.fc = nn.Linear(128 * 11 * 11, hidden_size)

    def forward(self, x: FloatTensor) -> FloatTensor:
        stages = (
            (self.conv1, self.norm1, self.att1),
            (self.conv2, self.norm2, self.att2),
            (self.conv3, self.norm3, self.att3),
        )
        features = x
        for conv, norm, attention in stages:
            features = attention(F.leaky_relu(norm(conv(features)), inplace=True))
        flattened = features.view(features.size(0), -1)
        return self.fc(flattened)
|
|
|
|
|
class ImageDecoder(nn.Module):
    """
    Decode a batch of hidden vectors back into 4-channel 84x84 images.

    A linear layer expands each vector into a 128x11x11 feature map. Three stride-2 transposed convolutions then
    upsample it: 11 -> 22 (resized to 21) -> 42 -> 84. The final `tanh` bounds the output to [-1, 1].

    Args:
        hidden_size (`int`):
            Size of the input hidden state.

    Returns:
        `torch.Tensor` of shape `(batch_size, 4, 84, 84)`:
            Output tensor representing the reconstructed images.
    """

    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        upsample = dict(kernel_size=3, stride=2, padding=1, output_padding=1)
        self.fc = nn.Linear(hidden_size, 128 * 11 * 11)
        self.deconv1 = nn.ConvTranspose2d(128, 64, **upsample)
        self.norm1 = nn.InstanceNorm2d(64)
        self.att1 = AttentionLayer(64)
        self.deconv2 = nn.ConvTranspose2d(64, 32, **upsample)
        self.norm2 = nn.InstanceNorm2d(32)
        self.att2 = AttentionLayer(32)
        self.deconv3 = nn.ConvTranspose2d(32, 4, **upsample)

    def forward(self, x: FloatTensor) -> FloatTensor:
        projected = self.fc(x)
        grid = projected.view(projected.size(0), 128, 11, 11)
        grid = F.leaky_relu(self.norm1(self.deconv1(grid)), inplace=True)
        # deconv1 yields 22x22; shrink to 21x21 so the next two upsamplings land exactly on 84x84.
        grid = self.att1(F.interpolate(grid, size=(21, 21)))
        grid = self.att2(F.leaky_relu(self.norm2(self.deconv2(grid)), inplace=True))
        return torch.tanh(self.deconv3(grid))
|
|
|
|
|
class DualBatchReshapeWrapper(nn.Module):
    """
    Wrapper to make a module designed for a single batch work with a dual batch.

    The two leading dimensions `(n1, n2)` of the input are merged into one batch dimension before calling the
    wrapped module, and split again on its output.

    Args:
        module (`nn.Module`):
            Module to be wrapped. It must take `(batch, ...)` inputs and return `(batch, ...)` outputs.

    Returns:
        `torch.Tensor` of shape `(n1, n2, ...)`:
            The wrapped module's output with the two leading dimensions restored.
    """

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(self, x: FloatTensor) -> FloatTensor:
        n1, n2 = x.shape[:2]
        # `reshape` (unlike `view`) also accepts non-contiguous inputs/outputs; it only copies when it must.
        merged_output = self.module(x.reshape(n1 * n2, *x.shape[2:]))
        return merged_output.reshape(n1, n2, *merged_output.shape[1:])
|
|
|
|
|
@dataclass
class JatOutput(ModelOutput):
    """
    Output of the Jat model.

    The model can be used for both RL and NLP tasks. For RL tasks, the model takes in observations and actions
    (`continuous_observations`, `discrete_actions`, etc.). For textual tasks, the model takes in a sequence of tokens
    and/or images (`input_ids`, `pixel_values`). Which fields are populated depends on the type of input.

    Args:
        loss (scalar `torch.FloatTensor`, *optional*, returned when `return_loss` is `True`):
            For RL input, `observation_loss_coef * observation_loss + action_loss_coef * action_loss`.
            For textual input, the causal language modeling loss.
        observation_loss (scalar `torch.FloatTensor` or `float`, *optional*, returned when `return_loss` is `True`):
            Only set for RL input. The MSE loss between predicted and true observations for continuous and image
            observations. It is `0.0` for discrete observations (their prediction is not supported) and whenever
            `observation_loss_coef` is `0.0`.
        action_loss (scalar `torch.FloatTensor`, *optional*, returned when `return_loss` is `True`):
            Only set for RL input. The MSE loss between predicted and true actions for continuous actions and the
            cross-entropy loss for discrete actions.
        pred_observations (`torch.FloatTensor` of shape `(batch_size, max_seq_len, ...)`, *optional*):
            Only set for RL input. Predicted observations from t=1 to t=max_seq_len+1. `None` for discrete
            observations and when `observation_loss_coef` is `0.0`.
        pred_actions (`torch.FloatTensor` of shape `(batch_size, max_seq_len, ...)`, *optional*):
            Only set for RL input. Predicted actions from t=0 to t=max_seq_len. When input actions are discrete,
            the predicted actions are logits.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, vocab_size)`, *optional*):
            Only set for textual input. Prediction scores of the language modeling head.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or
        when `config.use_cache=True`):
            One tuple per transformer layer holding the cached key and value tensors, each of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`. They can be passed back as
            `past_key_values` to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or
        when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when
        `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attention weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    """

    # NOTE: field order matters — it defines the order of the tuple returned by `ModelOutput.to_tuple()`.
    loss: Optional[FloatTensor] = None
    observation_loss: Optional[FloatTensor] = None
    action_loss: Optional[FloatTensor] = None
    pred_observations: Optional[FloatTensor] = None
    pred_actions: Optional[FloatTensor] = None
    logits: Optional[FloatTensor] = None
    past_key_values: Optional[Tuple[Tuple[FloatTensor]]] = None
    hidden_states: Optional[Tuple[FloatTensor]] = None
    attentions: Optional[Tuple[FloatTensor]] = None
|
|
|
|
|
class JatModel(GPTNeoPreTrainedModel):
    """
    Jat model: a GPT-Neo transformer shared between textual and reinforcement-learning (RL) inputs.

    Textual inputs (`input_ids` and/or `pixel_values`) are embedded with the transformer's token embedding and a
    ViT patch embedding, and decoded with a language-modeling head.

    RL inputs (observations, actions and rewards) are embedded with modality-specific encoders and interleaved as
    `obs_0, act_0, obs_1, act_1, ...`. The hidden state at each observation position predicts the action of the same
    timestep. The hidden state at each action position predicts the next observation.

    Args:
        config (`JatConfig`):
            Model configuration.
    """

    config_class = JatConfig

    def __init__(self, config: JatConfig) -> None:
        super().__init__(config)

        vocab_size = config.vocab_size
        hidden_size = config.hidden_size
        max_discrete_value = config.max_discrete_value
        max_continuous_size = config.max_continuous_size
        self.observation_loss_coef = config.observation_loss_coef
        self.action_loss_coef = config.action_loss_coef

        # Transformer backbone.
        self.transformer = GPTNeoModel(config)

        # Encoders.
        self.vit_encoder = ViTPatchEmbeddings(config)
        self.single_discrete_encoder = self.transformer.wte  # reuses the transformer's token embedding
        self.continuous_encoder = nn.Linear(max_continuous_size, hidden_size)
        # Produces `hidden_size - 1` features: `embed_rl` appends the reward as the last feature.
        self.multi_discrete_encoder = nn.Sequential(
            self.single_discrete_encoder,
            nn.Linear(hidden_size, hidden_size // 50),
            nn.ReLU(),
            nn.Flatten(start_dim=2),
            nn.Linear(max_discrete_value * (hidden_size // 50), hidden_size - 1),
        )
        self.image_encoder = DualBatchReshapeWrapper(ImageEncoder(hidden_size))

        # Decoders.
        self.single_discrete_decoder = nn.Linear(hidden_size, vocab_size, bias=False)
        self.continuous_decoder = nn.Linear(hidden_size, max_continuous_size)
        self.multi_discrete_decoder = nn.Sequential(
            nn.Linear(hidden_size, max_discrete_value * (hidden_size // 50)),
            nn.Unflatten(dim=2, unflattened_size=(max_discrete_value, hidden_size // 50)),
            nn.ReLU(),
            nn.Linear(hidden_size // 50, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 8, bias=False),
        )
        self.image_decoder = DualBatchReshapeWrapper(ImageDecoder(hidden_size))

        self.post_init()

    def embed_textual(
        self,
        input_ids: Optional[LongTensor],
        pixel_values: Optional[FloatTensor] = None,
        attention_mask: Optional[BoolTensor] = None,
    ) -> Tuple[Tensor, Optional[BoolTensor]]:
        """
        Embed textual inputs, prepending the image patch embeddings (if any) to the token embeddings.

        Returns:
            A tuple `(inputs_embeds, attention_mask)`. When both images and tokens are provided, the returned mask
            covers the image patches (always attended) followed by the text tokens.

        Raises:
            ValueError: If neither `input_ids` nor `pixel_values` is provided.
        """
        text_inputs_embeds = self.single_discrete_encoder(input_ids) if input_ids is not None else None
        image_inputs_embeds = self.vit_encoder(pixel_values) if pixel_values is not None else None

        if image_inputs_embeds is not None and text_inputs_embeds is not None:
            inputs_embeds = torch.cat((image_inputs_embeds, text_inputs_embeds), dim=1)
            # Image patches are always attended to; extend the (possibly implicit) text mask accordingly.
            image_mask = torch.ones(image_inputs_embeds.shape[:2], dtype=torch.bool, device=self.device)
            if attention_mask is None:
                attention_mask = torch.ones(text_inputs_embeds.shape[:2], dtype=torch.bool, device=self.device)
            attention_mask = torch.cat((image_mask, attention_mask), dim=1)
        elif image_inputs_embeds is not None:
            inputs_embeds = image_inputs_embeds
        elif text_inputs_embeds is not None:
            inputs_embeds = text_inputs_embeds
        else:
            raise ValueError("At least one of `input_ids` or `pixel_values` must be provided.")
        return inputs_embeds, attention_mask

    def embed_rl(
        self,
        continuous_observations: Optional[FloatTensor] = None,
        discrete_observations: Optional[LongTensor] = None,
        image_observations: Optional[FloatTensor] = None,
        continuous_actions: Optional[FloatTensor] = None,
        discrete_actions: Optional[LongTensor] = None,
        rewards: Optional[FloatTensor] = None,
        attention_mask: Optional[BoolTensor] = None,
    ) -> Tuple[Tensor, Optional[BoolTensor]]:
        """
        Embed an RL episode chunk and interleave observations and actions along the sequence dimension.

        Rewards are appended to continuous observations (before encoding) and to discrete-observation embeddings
        (after encoding); image observations are encoded without the reward.

        Returns:
            A tuple `(inputs_embeds, attention_mask)` where `inputs_embeds` has shape
            `(batch_size, 2 * seq_len, hidden_size)` ordered `obs_0, act_0, obs_1, act_1, ...`, and the mask (if
            given) is repeated so each timestep covers both of its positions.

        Raises:
            ValueError: If `rewards`, the observations or the actions are missing.
        """
        if rewards is None:
            raise ValueError("`rewards` must be provided for RL inputs.")

        # Pad continuous inputs cyclically to the fixed encoder input size.
        if continuous_observations is not None:
            continuous_observations = torch.cat((continuous_observations, rewards.unsqueeze(-1)), dim=-1)
            continuous_observations = cyclic_expand_dim(continuous_observations, self.config.max_continuous_size)
        if continuous_actions is not None:
            continuous_actions = cyclic_expand_dim(continuous_actions, self.config.max_continuous_size)

        # Encode observations.
        if continuous_observations is not None:
            batch_size, seq_len = continuous_observations.shape[:2]
            inputs_embeds_observations = self.continuous_encoder(continuous_observations)
        elif discrete_observations is not None:
            batch_size, seq_len = discrete_observations.shape[:2]
            inputs_embeds_observations = self.multi_discrete_encoder(discrete_observations)
            inputs_embeds_observations = torch.cat((inputs_embeds_observations, rewards.unsqueeze(-1)), dim=-1)
        elif image_observations is not None:
            batch_size, seq_len = image_observations.shape[:2]
            inputs_embeds_observations = self.image_encoder(image_observations)
        else:
            raise ValueError("Missing observations.")

        # Encode actions.
        if continuous_actions is not None:
            inputs_embeds_actions = self.continuous_encoder(continuous_actions)
        elif discrete_actions is not None:
            inputs_embeds_actions = self.single_discrete_encoder(discrete_actions)
        else:
            raise ValueError("Missing actions.")

        # Concatenate on the feature dim, then view as (batch, 2 * seq_len, hidden): this interleaves obs/act.
        inputs_embeds = torch.cat((inputs_embeds_observations, inputs_embeds_actions), dim=2)
        inputs_embeds = inputs_embeds.view(batch_size, 2 * seq_len, self.config.hidden_size)
        if attention_mask is not None:
            attention_mask = torch.repeat_interleave(attention_mask, repeats=2, dim=1)
        return inputs_embeds, attention_mask

    def output_textual(
        self,
        transformer_outputs,
        input_ids: Optional[LongTensor] = None,
        attention_mask: Optional[BoolTensor] = None,
        return_loss: bool = True,
        return_dict: Optional[bool] = None,
    ):
        """
        Apply the language-modeling head and optionally compute the causal LM loss on the text tokens.

        Image patches (if any) precede the text, so only the last `input_ids.shape[1]` positions are scored.

        Raises:
            ValueError: If `return_loss=True` and `input_ids` is `None`.
        """
        hidden_states = transformer_outputs[0]
        loss = None

        lm_logits = self.single_discrete_decoder(hidden_states)
        if return_loss:
            if input_ids is None:
                raise ValueError("Input IDs must be provided when `return_loss=True`.")

            # Shift so that tokens < n predict token n; restrict to the trailing text positions.
            num_text_tokens = input_ids.shape[1]
            shift_logits = lm_logits[:, -num_text_tokens:-1, :].contiguous()
            shift_labels = input_ids[:, 1:].contiguous()
            if attention_mask is not None:
                shift_attention_mask = attention_mask[:, -num_text_tokens:]
                shift_attention_mask = shift_attention_mask[:, 1:]
            else:
                shift_attention_mask = torch.ones(shift_labels.shape, dtype=bool, device=self.device)
            shift_logits = shift_logits[shift_attention_mask.bool()]
            shift_labels = shift_labels[shift_attention_mask.bool()]
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return JatOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    def output_rl(
        self,
        transformer_outputs,
        continuous_observations: Optional[FloatTensor] = None,
        discrete_observations: Optional[LongTensor] = None,
        image_observations: Optional[FloatTensor] = None,
        continuous_actions: Optional[FloatTensor] = None,
        discrete_actions: Optional[LongTensor] = None,
        rewards: Optional[FloatTensor] = None,
        attention_mask: Optional[BoolTensor] = None,
        return_loss: bool = True,
        return_dict: Optional[bool] = None,
        loss_weight: Optional[FloatTensor] = None,
    ):
        """
        Decode observation and action predictions from the transformer outputs, optionally with their losses.

        `attention_mask` is the interleaved mask produced by `embed_rl` (length `2 * seq_len`).

        Raises:
            ValueError: If `rewards` or the actions are missing.
        """
        # Index instead of `.last_hidden_state` so this also works when the transformer returned a plain tuple
        # (`return_dict=False`).
        hidden_states = transformer_outputs[0]
        loss, observation_loss, action_loss = None, None, None

        if rewards is None:
            raise ValueError("`rewards` must be provided for RL inputs.")

        # Observations are predicted from the hidden states at action positions (odd indices).
        observations_mask = attention_mask[:, 1::2] if attention_mask is not None else None
        if continuous_observations is not None:
            if self.observation_loss_coef == 0.0:
                warnings.warn("observation_loss_coef is 0.0, skipping memory-intensive observations prediction.")
                pred_observations = None
                observation_loss = 0.0
            else:
                obs_size = continuous_observations.shape[-1]
                continuous_observations = torch.cat((continuous_observations, rewards.unsqueeze(-1)), dim=-1)
                continuous_observations = cyclic_expand_dim(continuous_observations, self.config.max_continuous_size)
                pred_observations = self.continuous_decoder(hidden_states[:, 1::2])
                if return_loss:
                    # The prediction made at step t is compared with the observation at step t + 1.
                    observation_loss = compute_mse_loss(
                        pred_observations[:, :-1],
                        continuous_observations[:, 1:],
                        observations_mask[:, 1:] if observations_mask is not None else None,
                        weights=loss_weight[:, 1:] if loss_weight is not None else None,
                    )
                # Drop the reward and cyclic padding from the returned prediction.
                pred_observations = pred_observations[..., :obs_size]
        elif discrete_observations is not None:
            if self.observation_loss_coef == 0.0:
                warnings.warn("observation_loss_coef is 0.0, skipping memory-intensive observations prediction.")
                pred_observations = None
                observation_loss = 0.0
            else:
                warnings.warn("Discrete observations prediction are not supported yet.")
                pred_observations = None
                observation_loss = 0.0
        elif image_observations is not None:
            if self.observation_loss_coef == 0.0:
                warnings.warn("observation_loss_coef is 0.0, skipping memory-intensive observations prediction.")
                pred_observations = None
                observation_loss = 0.0
            else:
                pred_observations = self.image_decoder(hidden_states[:, 1::2])
                if return_loss:
                    observation_loss = compute_mse_loss(
                        pred_observations[:, :-1],
                        image_observations[:, 1:],
                        observations_mask[:, 1:] if observations_mask is not None else None,
                        weights=loss_weight[:, 1:] if loss_weight is not None else None,
                    )

        # Actions are predicted from the hidden states at observation positions (even indices).
        actions_mask = attention_mask[:, ::2] if attention_mask is not None else None
        if continuous_actions is not None:
            act_size = continuous_actions.shape[-1]
            continuous_actions = cyclic_expand_dim(continuous_actions, self.config.max_continuous_size)
            pred_actions = self.continuous_decoder(hidden_states[:, ::2])
            if return_loss:
                action_loss = compute_mse_loss(pred_actions, continuous_actions, actions_mask, weights=loss_weight)
            pred_actions = pred_actions[..., :act_size]
        elif discrete_actions is not None:
            pred_actions = self.single_discrete_decoder(hidden_states[:, ::2])
            if return_loss:
                action_loss = compute_ce_loss(pred_actions, discrete_actions, actions_mask, weights=loss_weight)
        else:
            raise ValueError("Missing actions.")

        if return_loss:
            loss = self.observation_loss_coef * observation_loss + self.action_loss_coef * action_loss

        if not return_dict:
            output = (pred_observations, pred_actions) + transformer_outputs[1:]
            return ((loss, observation_loss, action_loss) + output) if loss is not None else output

        return JatOutput(
            loss=loss,
            observation_loss=observation_loss,
            action_loss=action_loss,
            pred_observations=pred_observations,
            pred_actions=pred_actions,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    def forward(
        self,
        input_ids: Optional[LongTensor] = None,
        pixel_values: Optional[FloatTensor] = None,
        continuous_observations: Optional[FloatTensor] = None,
        discrete_observations: Optional[LongTensor] = None,
        image_observations: Optional[FloatTensor] = None,
        continuous_actions: Optional[FloatTensor] = None,
        discrete_actions: Optional[LongTensor] = None,
        rewards: Optional[FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[FloatTensor]]] = None,
        attention_mask: Optional[BoolTensor] = None,
        token_type_ids: Optional[LongTensor] = None,
        position_ids: Optional[LongTensor] = None,
        return_loss: bool = True,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        loss_weight: Optional[FloatTensor] = None,
    ) -> JatOutput:
        """
        Run the model on either textual inputs (`input_ids` / `pixel_values`) or RL inputs (observations, actions
        and rewards). Textual inputs take precedence when both kinds are given.

        Raises:
            ValueError: If no supported input is provided.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Embed the inputs.
        if input_ids is not None or pixel_values is not None:
            inputs_embeds, attention_mask = self.embed_textual(input_ids, pixel_values, attention_mask)
        elif (
            continuous_observations is not None or discrete_observations is not None or image_observations is not None
        ):
            inputs_embeds, attention_mask = self.embed_rl(
                continuous_observations,
                discrete_observations,
                image_observations,
                continuous_actions,
                discrete_actions,
                rewards,
                attention_mask,
            )
        else:
            raise ValueError("Input not provided.")

        # Run the transformer backbone on the embeddings.
        transformer_outputs = self.transformer(
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Decode according to the input type.
        if input_ids is not None or pixel_values is not None:
            return self.output_textual(transformer_outputs, input_ids, attention_mask, return_loss, return_dict)
        return self.output_rl(
            transformer_outputs,
            continuous_observations,
            discrete_observations,
            image_observations,
            continuous_actions,
            discrete_actions,
            rewards,
            attention_mask,
            return_loss,
            return_dict,
            loss_weight,
        )

    def reset_rl(self):
        """Clear the key/value cache and the stored previous step used by `get_next_action`."""
        self._last_key_values = None
        self.last_discrete_observation = None
        self.last_continuous_observation = None
        self.last_text_observation = None
        self.last_image_observation = None
        self.last_discrete_action = None
        self.last_continuous_action = None
        self.last_reward = None

    @torch.no_grad()
    def get_next_action(
        self,
        processor: JatProcessor,
        continuous_observation: Optional[List[float]] = None,
        discrete_observation: Optional[List[int]] = None,
        text_observation: Optional[str] = None,
        image_observation: Optional[np.ndarray] = None,
        action_space: Union[spaces.Box, spaces.Discrete] = None,
        reward: Optional[float] = None,
        deterministic: bool = False,
        context_window: Optional[int] = None,
    ):
        """
        Predict the action for the current environment step, reusing the key/value cache of previous steps.

        Call `reset_rl` at the start of each episode to clear the cache and the stored previous step.

        Args:
            processor (`JatProcessor`):
                Processor used to build the model inputs.
            continuous_observation / discrete_observation / text_observation / image_observation:
                The current observation; pass the kind(s) matching the environment.
            action_space (`spaces.Box` or `spaces.Discrete`):
                Action space of the environment.
            reward (`float`, *optional*):
                Reward paired with the current observation; `0.0` when not given.
            deterministic (`bool`, defaults to `False`):
                For discrete actions, take the argmax of the logits instead of sampling.
            context_window (`int`, *optional*):
                If set (non-negative), keep at most this many cached positions between calls.

        Returns:
            `List[float]` for a `Box` action space, `int` for a `Discrete` action space.

        Raises:
            ValueError: If `action_space` is neither a `spaces.Box` nor a `spaces.Discrete`.
        """
        # Make the first call work even if `reset_rl` was never called.
        if not hasattr(self, "_last_key_values"):
            self.reset_rl()

        # Each timestep uses two positions (observation + action).
        max_length = self.config.max_position_embeddings // 2

        def to_list(x):
            return x.tolist() if isinstance(x, np.ndarray) else x

        continuous_observation = to_list(continuous_observation)
        discrete_observation = to_list(discrete_observation)

        # The current action is unknown: feed a placeholder and read back its prediction.
        if isinstance(action_space, spaces.Box):
            fake_continuous_action = [0.0 for _ in range(action_space.shape[0])]
            fake_discrete_action = None
        elif isinstance(action_space, spaces.Discrete):
            fake_continuous_action = None
            fake_discrete_action = 0
        else:
            raise ValueError(
                f"Unsupported action space: {action_space!r}. Expected a `spaces.Box` or a `spaces.Discrete`."
            )

        continuous_observations = [continuous_observation] if continuous_observation is not None else None
        discrete_observations = [discrete_observation] if discrete_observation is not None else None
        text_observations = [text_observation] if text_observation is not None else None
        image_observations = [image_observation] if image_observation is not None else None
        continuous_actions = [fake_continuous_action] if fake_continuous_action is not None else None
        discrete_actions = [fake_discrete_action] if fake_discrete_action is not None else None
        rewards = [reward] if reward is not None else [0.0]

        if self._last_key_values is not None:
            # The previous step was dropped from the cache at the end of the last call: re-feed it, now with the
            # action that was actually predicted.
            continuous_observations = (
                [self.last_continuous_observation] + continuous_observations
                if continuous_observations is not None
                else None
            )
            discrete_observations = (
                [self.last_discrete_observation] + discrete_observations if discrete_observations is not None else None
            )
            text_observations = (
                [self.last_text_observation] + text_observations if text_observations is not None else None
            )
            image_observations = (
                [self.last_image_observation] + image_observations if image_observations is not None else None
            )
            continuous_actions = (
                [self.last_continuous_action] + continuous_actions if continuous_actions is not None else None
            )
            discrete_actions = [self.last_discrete_action] + discrete_actions if discrete_actions is not None else None
            rewards = [self.last_reward] + rewards

        # Remember the current step for the next call.
        self.last_continuous_observation = continuous_observations[-1] if continuous_observations is not None else None
        self.last_discrete_observation = discrete_observations[-1] if discrete_observations is not None else None
        self.last_text_observation = text_observations[-1] if text_observations is not None else None
        self.last_image_observation = image_observations[-1] if image_observations is not None else None
        self.last_reward = rewards[-1]

        # Add the batch dimension (batch of one episode).
        continuous_observations = [continuous_observations] if continuous_observations is not None else None
        discrete_observations = [discrete_observations] if discrete_observations is not None else None
        text_observations = [text_observations] if text_observations is not None else None
        image_observations = [image_observations] if image_observations is not None else None
        continuous_actions = [continuous_actions] if continuous_actions is not None else None
        discrete_actions = [discrete_actions] if discrete_actions is not None else None
        rewards = [rewards]

        processed = processor(
            continuous_observations=continuous_observations,
            discrete_observations=discrete_observations,
            text_observations=text_observations,
            image_observations=image_observations,
            continuous_actions=continuous_actions,
            discrete_actions=discrete_actions,
            rewards=rewards,
            truncation=True,
            truncation_side="left",
            max_length=max_length,
            return_tensors="pt",
        )
        processed.to(self.device)

        outputs = self(**processed, past_key_values=self._last_key_values, return_loss=False)

        # Keep at most `max_position_embeddings - 2` cached positions so the next call's inputs still fit.
        self._last_key_values = tuple(
            tuple(pkv[:, :, -self.config.max_position_embeddings + 2 :] for pkv in pkvs)
            for pkvs in outputs.past_key_values
        )

        # Drop the last two positions (current observation + placeholder action); they are re-fed next call.
        self._last_key_values = tuple(tuple(pkv[:, :, :-2] for pkv in pkvs) for pkvs in self._last_key_values)

        if context_window is not None:
            # Compute the start index explicitly: `pkv[:, :, -0:]` would keep the whole cache for a window of 0.
            self._last_key_values = tuple(
                tuple(pkv[:, :, max(pkv.size(2) - context_window, 0) :] for pkv in pkvs)
                for pkvs in self._last_key_values
            )

        # The prediction for the placeholder action is at the last timestep.
        if continuous_actions is not None:
            self.last_continuous_action = outputs.pred_actions[0, -1].cpu().tolist()
            return self.last_continuous_action
        elif discrete_actions is not None:
            logits = outputs.pred_actions[0, -1, : action_space.n]
            if deterministic:
                self.last_discrete_action = logits.argmax().cpu().item()
            else:
                self.last_discrete_action = torch.multinomial(logits.softmax(dim=-1), num_samples=1)[0].item()
            return self.last_discrete_action

    def prepare_inputs_for_generation(self, input_ids, pixel_values=None, past_key_values=None, **kwargs):
        """
        Build the model inputs for one generation step.

        Once a cache exists, only the last token is fed and `pixel_values` is dropped, since the earlier positions
        (including any image patches) were already processed in a previous step.
        """
        if past_key_values is not None:
            pixel_values = None
            input_ids = input_ids[:, -1].unsqueeze(-1)

        model_inputs = {
            "input_ids": input_ids,
            "pixel_values": pixel_values,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
        }

        return model_inputs
|
|
|
|
|
# Register `JatModel` under the `AutoModelForCausalLM` auto class.
JatModel.register_for_auto_class(auto_class="AutoModelForCausalLM")
|
|