import math

import mup
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from huggingface_hub import PyTorchModelHubMixin
from tqdm import tqdm
from transformers.utils import ModelOutput

from genie.factorization_utils import FactorizedEmbedding, factorize_labels
from genie.config import GenieConfig
from genie.st_transformer import STTransformerDecoder
from genie.attention import BasicCrossAttention


def modulate(x, shift, scale):
    """Apply an adaLN-style affine modulation: ``x * (1 + scale) + shift``."""
    return x * (1 + scale) + shift


def _primary_domain(domain):
    """Return the single domain key used to index the per-domain modules.

    Callers pass either a sequence of domain names (only the first entry is
    used, i.e. the whole batch is treated as one domain) or a single string.
    Indexing a plain string with ``[0]`` would silently yield its first
    character, so strings are returned unchanged.
    """
    if domain is None or isinstance(domain, str):
        return domain
    return domain[0]


def _xavier_init(module, gain):
    """Shared initializer used by the small heads in this file.

    Linear layers get Xavier-uniform weights with the given ``gain`` and zero
    bias; LayerNorms (when they have affine params) are reset to identity.
    """
    if isinstance(module, nn.Linear):
        torch.nn.init.xavier_uniform_(module.weight, gain=gain)
        if module.bias is not None:
            nn.init.constant_(module.bias, 0)
    elif isinstance(module, nn.LayerNorm):
        if module.bias is not None:
            nn.init.constant_(module.bias, 0)
        if module.weight is not None:
            nn.init.constant_(module.weight, 1.0)


class TokenResampler(nn.Module):
    """TokenResampler or Action Stem.

    A fixed set of ``token_num`` learnable query tokens cross-attends to the
    input vector of every (batch, timestep) pair, producing ``token_num``
    tokens per timestep.
    """

    def __init__(self, token_num, d_model, k_model, num_heads=8):
        super().__init__()
        self.token_num = token_num
        # Learnable query tokens with a small random init.
        self.tokens = nn.Parameter(torch.randn(1, token_num, d_model) * 0.01)
        self.cross_attention = BasicCrossAttention(
            num_heads=num_heads,
            d_model=d_model,
            k_model=k_model,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Resample ``x`` of shape (B, T, D) into (B, T, token_num, ...) tokens.

        The trailing dim of the output is whatever ``BasicCrossAttention``
        returns for the queries (presumably ``d_model`` -- confirm).
        """
        B, _, D = x.shape
        # Each (batch, timestep) pair becomes an independent length-1 key/value sequence.
        x = x.reshape(-1, 1, D)  # (B*T, 1, D)
        output_tokens = self.tokens.repeat(len(x), 1, 1)  # (B*T, token_num, d_model)
        output_tokens = self.cross_attention(output_tokens, x, x)
        return rearrange(output_tokens, "(b t) s d -> b t s d", b=B)


class ModulateLayer(nn.Module):
    """
    Modified from the final layer adopted from DiT with token-wise modulation.
    """

    def __init__(self, model_channels, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(out_channels, elementwise_affine=False, eps=1e-6)
        self.linear_out = nn.Linear(out_channels, out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.Linear(model_channels, model_channels),
            nn.SiLU(),
            nn.Linear(model_channels, 2 * out_channels, bias=True),
        )
        self.apply(self._init_weights)

    def forward(self, x, c):
        """Modulate ``x`` token-wise by shift/scale predicted from ``c``.

        Args:
            x: tensor laid out as ``(b s) t d``; ``b`` is taken from ``len(c)``.
            c: conditioning; only its first ``t`` entries along dim 1 are used
                (presumably (B, T, model_channels) -- confirm against callers).

        Returns:
            Tensor with the same shape as ``x``.
        """
        x_shape = x.shape
        x = rearrange(x, "(b s) t d -> b s t d", b=len(c))
        # Broadcast the conditioning over the ``s`` axis and crop it to x's ``t``.
        c = c[:, None, : x_shape[2]]
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear_out(x)
        return x.view(x_shape)

    def _init_weights(self, m):
        _xavier_init(m, gain=0.1)


class BasicMLP(nn.Module):
    """Two-layer MLP (Linear -> LayerNorm -> ReLU -> Linear) from d_action to d_model."""

    def __init__(self, d_action, d_model):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(d_action, d_model, bias=True),
            nn.LayerNorm(d_model),
            nn.ReLU(),
            nn.Linear(d_model, d_model, bias=True),
        )
        self.apply(self._init_weights)

    def forward(self, x):
        return self.model(x)

    def _init_weights(self, m):
        # TODO: muP?
        _xavier_init(m, gain=0.01)


def cosine_schedule(u):
    """Return ``cos(u * pi / 2)`` for ``u`` in [0, 1] (tensor or Python number)."""
    if isinstance(u, torch.Tensor):
        return torch.cos(u * torch.pi / 2)
    if isinstance(u, (int, float)):
        return math.cos(u * math.pi / 2)
    raise NotImplementedError(f"Unexpected {type(u)=} {u=}")


class ActionStat(nn.Module):
    """Per-dimension action normalizer holding ``mean``/``std`` buffers.

    ``input_info`` is ``(mean, std)``, each a sequence of length D.
    """

    def __init__(self, input_info):
        super().__init__()
        self.register_buffer("mean", torch.FloatTensor(input_info[0]))
        self.register_buffer("std", torch.FloatTensor(input_info[1]))

    def forward(self, x):
        # x: (B, T, S * D). T window length, S is the stride in the datasets, D action dimensions
        x = rearrange(x, "b t (s d) -> b t s d", d=len(self.mean))
        x = (x - self.mean) / (self.std + 1e-10)
        return rearrange(x, "b t s d -> b t (s d)", d=len(self.mean))

    def extra_repr(self):
        return f"mean={self.mean}, std={self.std}"

    def unnormalize(self, actions):
        """Invert ``forward``: map normalized (B, T, S * D) actions back to raw units."""
        actions = rearrange(actions, "b t (s d) -> b t s d", d=len(self.mean))
        actions = actions * (self.std + 1e-10) + self.mean
        return rearrange(actions, "b t s d -> b t (s d)", d=len(self.mean))


class STMaskGIT(nn.Module, PyTorchModelHubMixin):
    # Next-Token prediction as done in https://arxiv.org/pdf/2402.15391.pdf
    def __init__(self, config: GenieConfig):
        super().__init__()
        self.h = self.w = math.isqrt(config.S)
        assert self.h**2 == config.S, "Expected S to be square"

        # STTransformerDecoder
        self.decoder = STTransformerDecoder(
            num_layers=config.num_layers,
            num_heads=config.num_heads,
            d_model=config.d_model,
            qkv_bias=config.qkv_bias,
            proj_bias=config.proj_bias,
            qk_norm=config.qk_norm,
            use_mup=config.use_mup,
            attn_drop=config.attn_drop,
            mlp_ratio=config.mlp_ratio,
            mlp_bias=config.mlp_bias,
            mlp_drop=config.mlp_drop,
            action_processing=config.action_network,
            random_dummy_action=config.random_dummy_action,
            jointly_predict_actions=config.jointly_predict_actions,
            mask_token_id=config.image_vocab_size,
        )

        # learnable embedding for the maximum image sizes (image tokens + action tokens per frame)
        self.pos_embed_TSC = torch.nn.Parameter(
            torch.zeros(1, config.T, config.S + config.action_token_size, config.d_model)
        )
        print(f"{self.h=} {self.w=} {config.S=} {config.T=} {config.d_model=}")
        self.mask_token_id = config.image_vocab_size
        self.seq_len = config.S
        # Set by `forward` during training; selects which timesteps' action inputs are replaced by mask tokens.
        self.relevant_action_mask = None

        self.token_embed = FactorizedEmbedding(  # also works for num_factored_vocabs = 1
            factored_vocab_size=config.factored_vocab_size,
            num_factored_vocabs=config.num_factored_vocabs,
            d_model=config.d_model,
            mask_token_id=self.mask_token_id,
        )

        cls = FixedMuReadout if config.use_mup else nn.Linear  # (Fixed)MuReadout might slow down compiled training?
        self.out_x_proj = cls(config.d_model, config.factored_vocab_size * config.num_factored_vocabs)

        self.config = config
        self.action_mask_tokens = torch.nn.Parameter(torch.zeros(1, config.T, 1, config.d_model))
        if (self.config.init_actions or self.config.use_actions) and self.config.action_domains is not None:
            self.init_action_projectors(
                self.config.action_domains,
                self.config.d_actions,
                self.config.action_stats,
                self.config.action_network,
            )

    def init_action_projectors(
        self,
        domains: list[str],
        d_actions: list[int],
        action_stats: list[list[list[float]]],
        action_network: str = "mlp",
        use_diffusion: bool = False,
    ):
        """Create the per-domain action stems. It's called externally for training.

        For every domain this builds an ``ActionStat`` normalizer, a ``BasicMLP``
        input stem and (unless ``use_diffusion``) an output projector, and gives
        every decoder layer its own per-domain conditioning module selected by
        ``action_network`` ("mlp" -> Identity, "cross_attention", "modulate").
        """
        # We currently skip datasets if they fail but `domains` is all specified datasets,
        # so we get misalignment in this case. Validate before mutating the config.
        assert len(domains) == len(d_actions) == len(action_stats), \
            f"{len(domains)=} {len(d_actions)=} {len(action_stats)=}"

        self.config.init_actions = True
        self.config.action_domains = domains
        self.config.d_actions = d_actions
        self.config.action_stats = action_stats
        self.action_preprocessor = nn.ModuleDict()
        self.action_mlp = nn.ModuleDict()
        self.action_out_projectors = nn.ModuleDict()

        print("use diffusion: ", use_diffusion)
        print("init action network:", action_network)
        cls = FixedMuReadout if self.config.use_mup else nn.Linear  # (Fixed)MuReadout might slow down compiled training?

        for domain, d_action, action_stat in zip(domains, d_actions, action_stats):
            # by default, we share these modules across layers
            self.action_preprocessor[domain] = ActionStat(action_stat)
            self.action_mlp[domain] = BasicMLP(d_action, self.config.d_model)
            if not use_diffusion:
                self.action_out_projectors[domain] = cls(self.config.d_model, d_action)

        # by default, the conditioning modules are separate for each layer
        for layer in self.decoder.layers:
            layer.action_projectors = nn.ModuleDict()
            for domain, d_action in zip(domains, d_actions):
                if "mlp" in action_network:
                    layer.action_projectors[domain] = nn.Identity()
                elif "cross_attention" in action_network:
                    layer.action_projectors[domain] = BasicCrossAttention(
                        num_heads=8,
                        d_model=self.config.d_model,
                        k_model=d_action,
                    )
                elif "modulate" in action_network:
                    layer.action_projectors[domain] = ModulateLayer(self.config.d_model, self.config.d_model)

    def generate(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.LongTensor,
        max_new_tokens: int,
        min_new_tokens: int = None,
        return_logits: bool = False,
        return_with_actions: bool = False,
        maskgit_steps: int = 1,
        temperature: float = 0.0,
        action_ids: torch.Tensor = None,
        domain: str = "default",
        **kwargs,
    ) -> tuple[torch.LongTensor, torch.FloatTensor]:
        """
        Args designed to match the format of Llama.
        We ignore `attention_mask`, and use `max_new_tokens` to determine the number of frames to generate.

        Returns:
            - `(sample_THW, actions)` if `return_with_actions`; `actions` are the
              unnormalized actions from the last generated frame (None if no frame was generated).
            - `(sample_THW, factored_logits)` if `return_logits`.
            - `sample_THW` otherwise.

            sample_THW: size (B, (num_prompt_frames + num_new_frames) * H * W), the prompt followed by
                the autoregressively generated unfactorized token ids for future frames.
            factored_logits: size (B, factored_vocab_size, num_factored_vocabs, num_new_frames, H, W).
        """
        assert min_new_tokens in (None, max_new_tokens), \
            "Expecting `min_new_tokens`, if specified, to match `max_new_tokens`."

        h, w = self.h, self.w
        if "h" in kwargs:
            h = kwargs["h"][0]
        if "w" in kwargs:
            w = kwargs["w"][0]

        S = h * w
        num_new_frames = max_new_tokens // S
        inputs_THW = rearrange(input_ids.clone(), "b (t h w) -> b t h w", h=h, w=w)
        # Append fully-masked frames that will be filled in one at a time.
        inputs_masked_THW = torch.cat([
            inputs_THW,
            torch.full((input_ids.size(0), num_new_frames, h, w), self.mask_token_id,
                       dtype=torch.long, device=input_ids.device),
        ], dim=1)

        all_factored_logits = []
        actions = None
        for timestep in range(inputs_THW.size(1), inputs_THW.size(1) + num_new_frames):
            # could change sampling hparams
            sample_HW, factored_logits, actions = self.maskgit_generate(
                inputs_masked_THW,
                timestep,
                maskgit_steps=maskgit_steps,
                temperature=temperature,
                action_ids=action_ids,
                domain=domain,
                **kwargs,
            )
            inputs_masked_THW[:, timestep] = sample_HW
            all_factored_logits.append(factored_logits)

        predicted_tokens = rearrange(inputs_masked_THW, "B T H W -> B (T H W)")
        if return_with_actions:
            # model predicts actions in normalized space; map them back to raw units
            if actions is not None:
                actions = self.action_preprocessor[_primary_domain(domain)].unnormalize(actions)
            return predicted_tokens, actions
        elif return_logits:
            return predicted_tokens, torch.stack(all_factored_logits, dim=3)
        else:
            return predicted_tokens

    def init_mask(self, prompt_THW, t=1):
        """Return an all-False "already unmasked" tracker of shape (B, t * H * W).

        Uses the prompt's own H and W so non-default resolutions line up with
        the flattened samples in the MaskGIT loop.
        """
        _, _, h, w = prompt_THW.shape
        return torch.zeros(prompt_THW.size(0), t * h * w, dtype=torch.bool, device=prompt_THW.device)

    def _sample_tokens(self, logits_C, temperature):
        """Sample unfactorized token ids from logits of shape (B, num_vocabs * vocab_size, *grid).

        Returns:
            (samples, confidences), each of shape (B, *grid). ``confidences`` is the
            product over factors of the (untempered) probability of the sampled id.
        """
        vocab_size = self.config.factored_vocab_size
        factored_logits = rearrange(
            logits_C, "b (num_vocabs vocab_size) ... -> b vocab_size num_vocabs ...",
            vocab_size=vocab_size, num_vocabs=self.config.num_factored_vocabs,
        )
        factored_probs = F.softmax(factored_logits, dim=1)

        grid_shape = (factored_logits.size(0), *factored_logits.shape[3:])
        samples = torch.zeros(grid_shape, dtype=torch.long, device=logits_C.device)
        confidences = torch.ones(grid_shape, dtype=torch.float, device=logits_C.device)
        # Iterating the factors in reverse and multiplying by vocab_size each time
        # gives id = sum_i sample_i * vocab_size**i.
        for logits_v, probs in zip(factored_logits.flip(2).unbind(2), factored_probs.flip(2).unbind(2)):
            if temperature <= 1e-8:  # greedy sampling
                sample = probs.argmax(dim=1)
            else:
                # Temperature must scale the logits: dividing `probs` is a no-op
                # because Categorical renormalizes them.
                # Categorical expects last dim to be channel dim.
                dist = torch.distributions.categorical.Categorical(
                    logits=rearrange(logits_v.float(), "b vocab_size ... -> b ... vocab_size") / temperature
                )
                sample = dist.sample()
            samples *= vocab_size
            samples += sample
            confidences *= torch.gather(probs, 1, sample.unsqueeze(1)).squeeze(1)
        return samples, confidences

    def _maskgit_decode(
        self,
        prompt_THW: torch.LongTensor,
        out_t_min: int,
        out_t_max: int,
        maskgit_steps: int = 1,
        temperature: float = 0.0,
        unmask_mode: str = "random",
        action_ids=None,
        domain="default",
        show_progress: bool = False,
        **kwargs,
    ):
        """Jointly MaskGIT-decode frames ``[out_t_min, out_t_max)`` of ``prompt_THW`` in place.

        Returns:
            (samples_THW, factored_logits, action_outputs)
            samples_THW: (B, out_t_max - out_t_min, H, W) predicted unfactorized token ids.
            factored_logits: (B, factored_vocab_size, num_factored_vocabs, out_t_max - out_t_min, H, W),
                from the first forward pass (not after partial sampling).
            action_outputs: actions from the last `compute_logits` call (may be None).
        """
        # assume we have pre-masked z{out_t_min}...zT with all masks
        assert out_t_min > 0, "maskgit_generate requires out_t > 0"
        assert out_t_min < out_t_max <= prompt_THW.size(1), \
            f"invalid frame range [{out_t_min}, {out_t_max}) for {prompt_THW.size(1)} frames"
        assert torch.all(prompt_THW[:, out_t_min:] == self.mask_token_id), \
            f"when generating z{out_t_min}, frames {out_t_min} and later must be masked"

        bs, _, h, w = prompt_THW.shape
        num_frames = out_t_max - out_t_min
        num_tokens = num_frames * h * w

        # this will be modified in place on each iteration of this loop
        unmasked = self.init_mask(prompt_THW, t=num_frames)

        logits_CTHW, action_outputs = self.compute_logits(prompt_THW, action_ids=action_ids, domain=domain, **kwargs)
        # Return these original logits, not logits after partially sampling.
        orig_logits_CTHW = logits_CTHW[:, :, out_t_min:out_t_max].clone()

        steps = tqdm(range(maskgit_steps)) if show_progress else range(maskgit_steps)
        for step in steps:
            # Perform a single maskgit step (cosine schedule), updating unmasked in-place
            if step > 0:  # recompute logits with updated prompt
                logits_CTHW, action_outputs = self.compute_logits(
                    prompt_THW, action_ids=action_ids, domain=domain, **kwargs
                )

            samples_THW, confidences_THW = self._sample_tokens(
                logits_CTHW[:, :, out_t_min:out_t_max], temperature
            )

            prev_unmasked = unmasked.clone()
            prev_flat = rearrange(prompt_THW[:, out_t_min:out_t_max], "b t h w -> b (t h w)")
            samples_flat = samples_THW.reshape(bs, num_tokens)

            if step != maskgit_steps - 1:  # skip masking for last maskgit step
                # use cosine mask scheduling function, n is how many tokens to keep masked
                n = math.ceil(cosine_schedule((step + 1) / maskgit_steps) * num_tokens)
                if unmask_mode == "greedy":
                    # set the n patches with the least confidence to mask_token
                    confidences_flat = confidences_THW.reshape(bs, num_tokens)
                elif unmask_mode == "random":
                    # randomize confidences, so that patches are randomly masked;
                    # not a probability distribution anymore, only relative order matters
                    confidences_flat = torch.rand_like(confidences_THW).reshape(bs, num_tokens)
                else:
                    raise NotImplementedError(f"Expected `unmask_mode` to be one of ['greedy', 'random'], "
                                              f"got {unmask_mode}")

                # already-committed tokens must never be re-masked
                confidences_flat[unmasked] = torch.inf
                least_confident_tokens = torch.argsort(confidences_flat, dim=1)
                # unmask the (num_tokens - n) most confident tokens
                unmasked.scatter_(1, least_confident_tokens[:, n:], True)
                samples_flat.scatter_(1, least_confident_tokens[:, :n], self.mask_token_id)

            # copy previously unmasked values from prompt input into sample
            samples_flat[prev_unmasked] = prev_flat[prev_unmasked]
            samples_THW = samples_flat.reshape(bs, num_frames, h, w)

            # feed back to iteratively decode
            prompt_THW[:, out_t_min:out_t_max] = samples_THW

        factored_logits = rearrange(
            orig_logits_CTHW,
            "B (num_vocabs vocab_size) T H W -> B vocab_size num_vocabs T H W",
            vocab_size=self.config.factored_vocab_size,
            num_vocabs=self.config.num_factored_vocabs,
            H=h, W=w,
        )
        return samples_THW, factored_logits, action_outputs

    @torch.no_grad()
    def maskgit_generate(
        self,
        prompt_THW: torch.LongTensor,
        out_t: int,
        maskgit_steps: int = 1,
        temperature: float = 0.0,
        unmask_mode: str = "random",
        action_ids=None,
        domain="default",
        **kwargs,
    ) -> tuple[torch.LongTensor, torch.FloatTensor, "torch.FloatTensor | None"]:
        """
        Performs MaskGIT-style inference to predict frame `out_t`.

        Args:
            prompt_THW: Unfactorized token ids, size (B, T, H, W). Frame `out_t` is written in place.
            out_t: Will return predicted unfactorized token ids for this frame.
                Should be >= 1 as the 0th frame is assumed to be given.
                Expects all future frames to be fully masked.
            maskgit_steps: The number of MaskGIT-style inference steps to take.
            temperature: Sampling temperature (applied to the logits).
                In the factorized case, sampling is performed for each factorized vocabulary independently.
                If temperature is <= 1e-8, will be greedy (i.e. argmax) instead of actual sampling.
            unmask_mode: The method to determine tokens to unmask during each step of MaskGIT inference.
                Options:
                    - "greedy" for unmasking the most confident tokens, which matches the original MaskGIT
                    - "random" for randomly choosing tokens to unmask
                "greedy" tends to copy the previous frame, so we default to "random" instead.

        Returns:
            (sample_HW, factored_logits, action_outputs)
            sample_HW: size (B, H, W) corresponding to predicted unfactorized token ids for frame `out_t`.
            factored_logits: size (B, factored_vocab_size, num_factored_vocabs, H, W).
            action_outputs: predicted (normalized) actions from `compute_logits`, or None.
        """
        if out_t < 0:
            out_t += prompt_THW.size(1)
        samples_THW, factored_logits, action_outputs = self._maskgit_decode(
            prompt_THW,
            out_t,
            out_t + 1,
            maskgit_steps=maskgit_steps,
            temperature=temperature,
            unmask_mode=unmask_mode,
            action_ids=action_ids,
            domain=domain,
            **kwargs,
        )
        return samples_THW[:, 0], factored_logits[:, :, :, 0], action_outputs

    @torch.no_grad()
    def maskgit_generate_horizon(
        self,
        prompt_THW: torch.LongTensor,
        out_t_min: int,
        out_t_max: int,
        maskgit_steps: int = 1,
        temperature: float = 0.0,
        unmask_mode: str = "random",
        action_ids=None,
        domain="default",
        skip_normalization: bool = False,
        **kwargs,
    ) -> tuple[torch.LongTensor, torch.FloatTensor, "torch.FloatTensor | None"]:
        """
        Performs MaskGIT-style inference to jointly predict frames `out_t_min` ... `out_t_max - 1`.

        Args:
            prompt_THW: Unfactorized token ids, size (B, T, H, W). The predicted frames are written in place.
            out_t_min: First frame to predict. Should be >= 1 as the 0th frame is assumed to be given.
                Expects this frame and all later frames to be fully masked.
            out_t_max: One past the last frame to predict (at most T).
            maskgit_steps, temperature, unmask_mode: as in `maskgit_generate`; the cosine
                schedule runs over all tokens of the horizon at once.
            skip_normalization: forwarded to `compute_logits`; if True, `action_ids` are assumed
                to be normalized already.

        Returns:
            (samples_THW, factored_logits, action_outputs)
            samples_THW: size (B, out_t_max - out_t_min, H, W).
            factored_logits: size (B, factored_vocab_size, num_factored_vocabs, out_t_max - out_t_min, H, W).
            action_outputs: predicted (normalized) actions from `compute_logits`, or None.
        """
        return self._maskgit_decode(
            prompt_THW,
            out_t_min,
            out_t_max,
            maskgit_steps=maskgit_steps,
            temperature=temperature,
            unmask_mode=unmask_mode,
            action_ids=action_ids,
            domain=domain,
            show_progress=True,
            skip_normalization=skip_normalization,
            **kwargs,
        )

    def compute_video_loss_and_acc(self, logits_CTHW, targets_THW, relevant_mask_THW):
        """Masked video-token cross-entropy loss and accuracy.

        Args:
            logits_CTHW: (B, num_vocabs * vocab_size, T, H, W) logits.
            targets_THW: (B, T * H * W) unfactorized target ids.
            relevant_mask_THW: (B, T - 1, H, W) mask of positions to score (frame 0 excluded).

        Returns:
            (relevant_loss, relevant_acc) averaged over the masked positions.
        """
        # Derive the grid from the logits so non-default resolutions (h/w kwargs) work.
        T, H, W = logits_CTHW.shape[2:]
        targets_THW = targets_THW.clone()
        targets_THW = rearrange(targets_THW, "B (T H W) -> B T H W", T=T, H=H, W=W)
        logits_CTHW, targets_THW = logits_CTHW[:, :, 1:], targets_THW[:, 1:]  # first frame always unmasked
        factored_logits = rearrange(
            logits_CTHW, "b (num_vocabs vocab_size) t h w -> b vocab_size num_vocabs t h w",
            vocab_size=self.config.factored_vocab_size,
            num_vocabs=self.config.num_factored_vocabs,
        )
        factored_targets = factorize_labels(targets_THW)

        # adding label_smoothing; summed over the factored vocabularies
        loss_THW = F.cross_entropy(
            factored_logits, factored_targets, reduction="none", label_smoothing=0.01
        ).sum(dim=1)
        # a token is correct only if every factor is correct
        acc_THW = (factored_logits.argmax(dim=1) == factored_targets).all(dim=1)

        # Compute the mean masked error.
        # Multiply loss values by mask instead of indexing them, more computationally efficient.
        # Clamp so a batch with no masked tokens yields 0 instead of NaN.
        num_masked_tokens = torch.sum(relevant_mask_THW).clamp(min=1)
        relevant_loss = torch.sum(loss_THW * relevant_mask_THW) / num_masked_tokens
        relevant_acc = torch.sum(acc_THW * relevant_mask_THW).float() / num_masked_tokens

        # only optimize on the masked/noised logits.
        return relevant_loss, relevant_acc

    def compute_logits(self, x_THW: torch.Tensor, action_ids: torch.Tensor = None, domain=None, **kwargs):
        """Run the decoder on video tokens (and optional actions).

        Args:
            x_THW: (B, T, H, W) unfactorized token ids (z0, ..., zT).
            action_ids: optional actions, normalized by the domain's `ActionStat`
                unless `kwargs["skip_normalization"]` is truthy.
            domain: domain name or sequence of names (the first one is used).

        Returns:
            (decoded_states, decoded_actions): decoded_states is (B, C, T, H, W) logits if
            `config.jointly_predict_states`, else None; decoded_actions is the pooled action
            prediction if `config.jointly_predict_actions`, else None.
        """
        h, w = self.h, self.w
        if "h" in kwargs:
            assert "w" in kwargs
            h = kwargs["h"][0]
            w = kwargs["w"][0]

        domain_key = _primary_domain(domain)
        x_TS = rearrange(x_THW, "B T H W -> B T (H W)")
        x_TSC = self.token_embed(x_TS)
        T = x_TSC.shape[1]

        if action_ids is not None:
            # currently, action_preprocessor just normalizes the actions
            if not kwargs.get("skip_normalization", False):
                action_ids = self.action_preprocessor[domain_key](action_ids)
            action_ids = self.action_mlp[domain_key](action_ids)  # [B, T, D]

            if "concat" in self.config.action_network:
                action_condition = action_ids[:, :T, None].repeat(1, 1, self.config.action_token_size, 1)  # [B, T, S, D]
                if self.relevant_action_mask is not None and self.config.jointly_predict_actions:
                    # replace the selected timesteps' action inputs with the learnable mask token
                    action_condition = self.relevant_action_mask[:, :T] * self.action_mask_tokens[:, :T] + \
                        (1 - self.relevant_action_mask[:, :T]) * action_condition[:, :T]
                x_TSC = torch.concat((x_TSC, action_condition), dim=2)  # [B, T, S, D]

        elif self.config.jointly_predict_actions:
            # all masked: there are no input actions and we try to predict actions as in a policy
            action_condition = self.action_mask_tokens[:, :T].repeat(1, 1, self.config.action_token_size, 1)
            x_TSC = torch.concat((x_TSC, action_condition), dim=2)  # [B, T, S, D]

        # additive position embeddings, using the same vocab space
        x_TSC = self.decoder(
            x_TSC + self.pos_embed_TSC[:, :x_TSC.shape[1], :x_TSC.shape[2]],
            action_ids=action_ids,
            domain=domain_key,
        )

        decoded_actions = None
        decoded_states = None
        if self.config.jointly_predict_actions:
            decoded_actions = x_TSC[:, :, -self.config.action_token_size:].mean(dim=2)  # pool all action tokens
            decoded_actions = self.action_out_projectors[domain_key](decoded_actions)

        if self.config.jointly_predict_states:
            x_TSC = x_TSC[:, :, :h * w]  # remove action tokens
            x_next_TSC = self.out_x_proj(x_TSC)
            decoded_states = rearrange(x_next_TSC, "B T (H W) C -> B C T H W", H=h, W=w)

        return decoded_states, decoded_actions

    def forward(self, input_ids, labels, action_ids=None, domain="default", **kwargs):
        """
        input_ids: size (B, T * H * W) represents video sequences
        labels: size (B, T * H * W) represents video sequences
        action_ids: size (B, T, Da) represents action sequences

        Returns a `ModelOutput` with `loss`, `acc`, `logits`, plus `action_loss` and
        `actions` when actions are both given and predicted.
        """
        # if h and w in kwargs, update them. support varying resolutions.
        T, H, W = self.config.T, self.h, self.w
        if "h" in kwargs:
            H = kwargs["h"][0]
        if "w" in kwargs:
            W = kwargs["w"][0]

        x_THW = rearrange(input_ids, "B (T H W) -> B T H W", T=T, H=H, W=W)

        action_labels = None
        if action_ids is not None:
            # Predictions live in normalized action space (see `generate`), so the
            # targets must be normalized the same way `compute_logits` normalizes inputs.
            action_labels = action_ids.clone()
            if not kwargs.get("skip_normalization", False):
                action_labels = self.action_preprocessor[_primary_domain(domain)](action_labels)

            # in training we mask a random fraction of timesteps, between 0 (fully unmasked
            # as in video pred) and 1 (fully masked as in policies), for training losses
            drop_ratio = torch.rand(len(action_ids), 1, 1, device=action_ids.device)
            action_mask = torch.rand(len(action_ids), T, 1, device=action_ids.device) < drop_ratio
            self.relevant_action_mask = action_mask.unsqueeze(-1).to(self.action_mask_tokens.dtype)

        # Record the loss over masked tokens only to make it more comparable to LLM baselines
        logits_CTHW, action_outputs = self.compute_logits(x_THW, action_ids=action_ids, domain=domain, **kwargs)
        relevant_mask = x_THW[:, 1:] == self.mask_token_id  # could also get mask of corrupted tokens by uncommenting line in `get_maskgit_collator`

        relevant_loss = torch.zeros(1, device=x_THW.device)
        relevant_acc = torch.zeros(1, device=x_THW.device)
        if logits_CTHW is not None:
            relevant_loss, relevant_acc = self.compute_video_loss_and_acc(logits_CTHW, labels, relevant_mask)

        if action_outputs is not None and action_labels is not None:
            # per-element loss so the timestep mask actually selects the masked actions
            action_loss = F.mse_loss(action_outputs, action_labels, reduction="none")
            action_loss = (action_loss * self.relevant_action_mask[..., 0]).mean()
            return ModelOutput(
                loss=relevant_loss,
                acc=relevant_acc,
                logits=logits_CTHW,
                action_loss=action_loss,
                actions=action_outputs,
            )

        return ModelOutput(loss=relevant_loss, acc=relevant_acc, logits=logits_CTHW)

    def init_weights(self):
        """ Works with and without muP. """
        std = 0.02
        for module in self.modules():
            if isinstance(module, nn.Linear):
                if hasattr(module.weight, "infshape"):  # muP
                    mup.normal_(module.weight, mean=0.0, std=std)
                else:
                    module.weight.data.normal_(mean=0.0, std=std)

                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.Embedding):
                module.weight.data.normal_(mean=0.0, std=std)
                if module.padding_idx is not None:
                    module.weight.data[module.padding_idx].zero_()

    def set_mup_shapes(self, rescale_params=False):
        """Register muP base shapes against a fixed-size (8 heads, d_model=256) base model."""
        base_config = self.config.shallow_copy()
        base_config.num_heads = 8
        base_config.d_model = 256  # currently hardcoding to this shape
        base_model = STMaskGIT(base_config)

        mup.set_base_shapes(self, base_model, rescale_params=rescale_params)

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """ Extra logic for muP. """
        model = super().from_pretrained(*args, **kwargs)
        if model.config.use_mup:
            model.set_mup_shapes(rescale_params=False)

        return model


class FixedMuReadout(mup.MuReadout):
    """`mup.MuReadout` with this file's Xavier init and a torch.compile-friendly forward."""

    def __init__(self, d_input, d_output):
        super().__init__(d_input, d_output)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # TODO: muP?
        _xavier_init(m, gain=0.01)

    def forward(self, x):
        """
        Using `return super(mup.MuReadout, self).forward(self.output_mult * x / self.width_mult())` with `torch.compile`
        results in two divisions by `self.width_mult()` for some reason
        """
        # return F.linear(self.output_mult * x / self.width_mult(), self.weight, self.bias)  # equivalent
        return nn.Linear.forward(self, self.output_mult * x / self.width_mult())