File size: 10,877 Bytes
21e5dd0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 |
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers import LlamaForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration_llama_action import LlamaActionConfig
class LearnableFactorizedSpatioTemporalPositionalEmbedding(nn.Module):
    """Learnable positional embedding factorized into a spatial and a temporal table.

    A flat sequence position ``pos`` is mapped to the spatial slot
    ``pos % num_spatio_embeddings`` and the temporal slot
    ``pos // num_spatio_embeddings``; the returned embedding is the sum of the
    two looked-up vectors.  Positions ``0 .. num_spatio_embeddings *
    num_temporal_embeddings - 1`` are representable.

    Args:
        num_spatio_embeddings: Size of the spatial table (positions per
            temporal step).
        num_temporal_embeddings: Size of the temporal table (number of
            temporal steps).
        embedding_dim: Width of each embedding vector.
    """

    def __init__(self, num_spatio_embeddings: int, num_temporal_embeddings: int, embedding_dim: int):
        super().__init__()
        self.spatio_embeddings = nn.Embedding(num_spatio_embeddings, embedding_dim)
        self.temporal_embeddings = nn.Embedding(num_temporal_embeddings, embedding_dim)
        self.num_spatio_embeddings = num_spatio_embeddings
        self.num_temporal_embeddings = num_temporal_embeddings

    def forward(self, attention_mask: torch.Tensor, past_key_values_length: int) -> torch.Tensor:
        """Return positional embeddings for the next ``seq_length`` positions.

        Args:
            attention_mask: Any tensor whose first two dimensions are
                ``(batch_size, seq_length)``; only its shape and device are
                read (in this file the caller passes the input embeddings).
            past_key_values_length: Number of positions already consumed by
                the cache.  The result covers positions
                ``past_key_values_length .. past_key_values_length + seq_length - 1``,
                so multi-token incremental steps get distinct positions.

        Returns:
            Tensor of shape ``(batch_size, seq_length, embedding_dim)``.

        Raises:
            ValueError: if the requested positions exceed
                ``num_spatio_embeddings * num_temporal_embeddings``.
        """
        batch_size = attention_mask.size(0)
        seq_length = attention_mask.size(1)
        end = past_key_values_length + seq_length
        capacity = self.num_spatio_embeddings * self.num_temporal_embeddings
        if end > capacity:
            raise ValueError(
                f"positions up to {end - 1} requested but only {capacity} "
                f"(num_spatio_embeddings * num_temporal_embeddings) are available"
            )
        positions = torch.arange(past_key_values_length, end, device=attention_mask.device)
        # Row-major factorization: the spatial index cycles fastest, the
        # temporal index advances once every num_spatio_embeddings positions.
        spatio_indices = (positions % self.num_spatio_embeddings).unsqueeze(0).repeat(batch_size, 1)
        temporal_indices = (positions // self.num_spatio_embeddings).unsqueeze(0).repeat(batch_size, 1)
        return self.spatio_embeddings(spatio_indices) + self.temporal_embeddings(temporal_indices)
class LlamaActionForCausalLM(LlamaForCausalLM):
    """Llama causal LM over image tokens interleaved with projected action vectors.

    Image token ids are embedded with the backbone's input embedding table and
    split along the sequence into frames of ``num_image_patches`` tokens.  Each
    complete frame is followed by its chunk of ``num_action_embeddings`` action
    vectors, projected to ``hidden_size`` by ``action_projection``.  A learnable
    factorized spatio-temporal positional embedding is added to the resulting
    sequence before it is fed to the Llama backbone as ``inputs_embeds``.

    NOTE(review): the cache-length arithmetic below (``past_length %
    num_spatio_embeddings == num_spatio_embeddings - num_action_embeddings``)
    presumably assumes ``num_spatio_embeddings == num_image_patches +
    num_action_embeddings``; the config does not enforce this — confirm.
    """

    config_class = LlamaActionConfig

    def __init__(self, config: LlamaActionConfig):
        super().__init__(config)
        self.num_spatio_embeddings = config.num_spatio_embeddings
        self.num_temporal_embeddings = config.num_temporal_embeddings
        self.num_image_patches = config.num_image_patches
        self.num_action_embeddings = config.num_action_embeddings
        self.pos_embedding_spatio_temporal = LearnableFactorizedSpatioTemporalPositionalEmbedding(
            config.num_spatio_embeddings, config.num_temporal_embeddings, config.hidden_size,
        )
        self.action_projection = nn.Linear(config.action_dim, config.hidden_size)
        self.post_init()

    @staticmethod
    def _past_length(past_key_values) -> int:
        """Return the number of cached positions, 0 when there is no (or an empty) cache."""
        if past_key_values is None:
            return 0
        if isinstance(past_key_values, (tuple, list)):
            # Legacy per-layer (key, value) format; dim 2 of the key is taken as
            # the sequence axis (assumes (batch, heads, seq, head_dim) — confirm).
            return past_key_values[0][0].size(2) if len(past_key_values) > 0 else 0
        return past_key_values.get_seq_length()

    def _split_actions(self, actions: Optional[torch.Tensor]):
        """Split ``actions`` along dim 1 into chunks of ``num_action_embeddings``."""
        if actions is None:
            raise ValueError("`actions` must be provided to interleave action embeddings")
        return torch.split(actions, split_size_or_sections=self.num_action_embeddings, dim=1)

    def _interleave_prefill(self, image_embeds: torch.Tensor, actions: Optional[torch.Tensor]) -> torch.Tensor:
        """Build ``frame_0, actions_0, frame_1, actions_1, ...`` for a cache-less forward pass.

        A frame's actions are appended only when the frame is complete, so a
        partially generated last frame is not followed by actions (this keeps
        the layout in agreement with ``prepare_inputs_for_generation``).
        """
        frames = torch.split(image_embeds, split_size_or_sections=self.num_image_patches, dim=1)
        action_chunks = self._split_actions(actions)
        if len(frames) > len(action_chunks):
            raise ValueError(
                f"got {len(frames)} image frames but only {len(action_chunks)} action chunks; "
                f"every image frame needs a matching chunk of actions"
            )
        pieces = []
        for frame, action_chunk in zip(frames, action_chunks):
            pieces.append(frame)
            if frame.size(1) == self.num_image_patches:
                pieces.append(self.action_projection(action_chunk))
        return torch.cat(pieces, dim=1)

    def _interleave_decode(
        self, image_embeds: torch.Tensor, actions: Optional[torch.Tensor], past_length: int
    ) -> torch.Tensor:
        """Prepend the finished frame's actions to an incremental step when due.

        When ``past_length % num_spatio_embeddings`` equals
        ``num_spatio_embeddings - num_action_embeddings`` the cache ends right
        after a frame's image tokens, so that frame's projected actions are
        inserted before the new image token(s).  Otherwise the input is
        returned unchanged.
        """
        if past_length % self.num_spatio_embeddings != self.num_spatio_embeddings - self.num_action_embeddings:
            return image_embeds
        frame_index = past_length // self.num_spatio_embeddings
        action_chunks = self._split_actions(actions)
        action_features = self.action_projection(action_chunks[frame_index])
        return torch.cat([action_features, image_embeds], dim=1)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        actions: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.Tensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithPast]:
        """Run the backbone on image tokens interleaved with projected actions.

        Args:
            input_ids: Image token ids; embedded with the backbone's input
                embeddings.  Mutually exclusive with ``inputs_embeds``.
            actions: Action vectors, split along dim 1 into chunks of
                ``num_action_embeddings`` (one chunk per frame) and projected
                by ``action_projection``.
            inputs_embeds: Pre-computed image token embeddings; never modified
                in place.
            labels: If given, a shifted next-token cross-entropy loss over
                ``logits`` is computed and caching is disabled.  Must align
                with the interleaved sequence length.
            past_key_values: Legacy tuple/list cache or a cache object exposing
                ``get_seq_length()``.
            (remaining arguments are forwarded to the backbone unchanged)

        Returns:
            ``CausalLMOutputWithPast``, or a tuple when ``return_dict`` is false.

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds``
                are given, if ``actions`` is missing where needed, or if there
                are more image frames than action chunks.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if labels is not None:
            use_cache = False
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        if inputs_embeds is None:
            inputs_embeds = self.model.get_input_embeddings()(input_ids)

        past_key_values_length = self._past_length(past_key_values)
        if past_key_values_length == 0:
            inputs_embeds = self._interleave_prefill(inputs_embeds, actions)
        else:
            inputs_embeds = self._interleave_decode(inputs_embeds, actions, past_key_values_length)

        # Out-of-place add so a caller-supplied inputs_embeds is never mutated.
        inputs_embeds = inputs_embeds + self.pos_embedding_spatio_temporal(inputs_embeds, past_key_values_length)

        outputs = self.model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        logits = self.lm_head(sequence_output).contiguous()

        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        use_cache=None,
        **kwargs):
        """Build per-step model inputs for ``generate``.

        The incoming ``attention_mask`` is ignored.  A new all-ones mask is
        built whose length matches the interleaved sequence ``forward`` will
        see:
        - without a cache, each complete frame contributes
          ``num_image_patches + num_action_embeddings`` positions, and a
          trailing partial frame contributes its own tokens;
        - with a cache, the length is the cached length plus the new tokens,
          plus ``num_action_embeddings`` when ``forward`` will insert actions
          at this step.

        ``actions`` is taken from ``kwargs`` and forwarded as-is.
        """
        import logging

        batch_size = input_ids.size(0)
        seq_length = input_ids.size(1)
        n_frames, n_last_frame_tokens = divmod(seq_length, self.num_image_patches)
        if n_last_frame_tokens == 0:
            logging.getLogger(__name__).info(
                "attempting to generate new frame - frame no: %d", n_frames + 1
            )

        past_length = self._past_length(past_key_values)
        if past_length == 0:
            mask_seq_length = (
                n_frames * (self.num_image_patches + self.num_action_embeddings) + n_last_frame_tokens
            )
        else:
            # Cut tokens already in the cache.  past_length also counts the
            # inserted action positions, so it is usually >= the raw id count,
            # in which case only the newest token is kept.
            if seq_length > past_length:
                remove_prefix_length = past_length
            else:
                remove_prefix_length = seq_length - 1
            input_ids = input_ids[:, remove_prefix_length:]
            mask_seq_length = input_ids.size(1) + past_length
            # Mirror forward(): actions are inserted at this step.
            if past_length % self.num_spatio_embeddings == (self.num_spatio_embeddings - self.num_action_embeddings):
                mask_seq_length += self.num_action_embeddings
        attention_mask = torch.ones((batch_size, mask_seq_length), device=input_ids.device, dtype=torch.long)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "actions": kwargs.get("actions"),
            "past_key_values": past_key_values,
            "use_cache": use_cache,
        }
|