Spaces:
Sleeping
Sleeping
import math | |
import mup | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from einops import rearrange | |
from huggingface_hub import PyTorchModelHubMixin | |
from tqdm import tqdm | |
from transformers.utils import ModelOutput | |
from genie.factorization_utils import FactorizedEmbedding, factorize_labels | |
from genie.config import GenieConfig | |
from genie.st_transformer import STTransformerDecoder | |
from genie.attention import BasicCrossAttention | |
def modulate(x, shift, scale):
    """Affinely modulate ``x``: multiply by ``(1 + scale)`` and then add ``shift``."""
    gain = 1 + scale
    return x * gain + shift
class TokenResampler(nn.Module):
    """TokenResampler or Action Stem.

    Maps every input vector to ``token_num`` output tokens by letting a set of
    learnable query tokens cross-attend to that vector.
    """

    def __init__(self, token_num, d_model, k_model, num_heads=8):
        super().__init__()
        # initialize cross attention module and the learnable tokens
        self.token_num = token_num
        # small random init for the learnable query tokens, shape (1, token_num, d_model)
        self.tokens = nn.Parameter(torch.randn(1, token_num, d_model) * 0.01)
        self.cross_attention = BasicCrossAttention(
            num_heads=num_heads,
            d_model=d_model,
            k_model=k_model,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Computes the latent representations of input data by attention.

        Args:
            x: (B, T, D) tensor. Presumably D == k_model, the key/value width given to
               BasicCrossAttention (TODO confirm).

        Returns:
            (B, T, token_num, d_out), where d_out is the last dim produced by
            ``self.cross_attention``.
        """
        B, T, D = x.shape
        # Each (b, t) vector becomes its own length-1 key/value sequence. Use reshape
        # (not view) so that non-contiguous inputs, e.g. transposed tensors, also work.
        kv = x.reshape(B * T, 1, D)
        # replicate the query tokens once per (b, t) item
        queries = self.tokens.repeat(B * T, 1, 1)
        out = self.cross_attention(queries, kv, kv)  # (B * T, token_num, d)
        # "(b t) s d -> b t s d"
        return out.reshape(B, T, *out.shape[1:])
class ModulateLayer(nn.Module):
    """
    Token-wise modulation layer, modified from the final layer adopted from DiT.

    A conditioning tensor ``c`` is mapped to ``shift`` and ``scale`` vectors. These
    modulate a parameter-free LayerNorm of ``x``, and the result then goes through
    a linear projection.
    """

    def __init__(self, model_channels, out_channels):
        super().__init__()
        # Registration/creation order below fixes both the state_dict key order and
        # the order in which the random initializers consume the RNG.
        self.norm_final = nn.LayerNorm(out_channels, elementwise_affine=False, eps=1e-6)
        self.linear_out = nn.Linear(out_channels, out_channels, bias=True)
        hidden_proj = nn.Linear(model_channels, model_channels)
        shift_scale_proj = nn.Linear(model_channels, 2 * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(hidden_proj, nn.SiLU(), shift_scale_proj)
        self.apply(self._init_weights)

    def forward(self, x, c):
        """
        Modulate ``x`` with conditioning ``c`` and return a tensor with ``x``'s shape.

        ``x`` is regrouped as "(b s) t d -> b s t d" with b = len(c). The first
        ``x.shape[2]`` entries along ``c``'s second axis supply per-timestep
        shift/scale. ``c`` is presumably (B, >=T, model_channels) (TODO confirm).
        """
        orig_shape = x.shape
        grouped = rearrange(x, "(b s) t d -> b s t d", b=len(c))
        cond = c[:, None, :orig_shape[2]]
        shift, scale = self.adaLN_modulation(cond).chunk(2, dim=-1)
        projected = self.linear_out(modulate(self.norm_final(grouped), shift, scale))
        return projected.view(orig_shape)

    def _init_weights(self, m):
        """Xavier-uniform (gain 0.1) Linear weights, zero biases, unit LayerNorm weights."""
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight, gain=0.1)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # norm_final has elementwise_affine=False, so its weight/bias are None
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)
class BasicMLP(nn.Module):
    """Two-layer MLP, Linear -> LayerNorm -> ReLU -> Linear, mapping d_action to d_model."""

    def __init__(self, d_action, d_model):
        super().__init__()
        layers = [
            nn.Linear(d_action, d_model, bias=True),
            nn.LayerNorm(d_model),
            nn.ReLU(),
            nn.Linear(d_model, d_model, bias=True),
        ]
        # The attribute name `model` and this layer order define the state_dict keys
        # (model.0.*, model.1.*, model.3.*), so keep them stable for checkpoints.
        self.model = nn.Sequential(*layers)
        self.apply(self._init_weights)

    def forward(self, x):
        return self.model(x)

    def _init_weights(self, m):  # TODO: muP?
        """Small-gain (0.01) Xavier init for Linear layers; identity init for LayerNorm."""
        if isinstance(m, nn.LayerNorm):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)
            return
        if not isinstance(m, nn.Linear):
            return
        torch.nn.init.xavier_uniform_(m.weight, gain=0.01)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
def cosine_schedule(u):
    """Return the MaskGIT cosine schedule value ``cos(u * pi / 2)`` for u in [0, 1].

    Args:
        u: a ``torch.Tensor`` (evaluated elementwise with torch) or a real Python
           number. Both ``int`` and ``float`` are accepted and evaluated with ``math``.

    Raises:
        NotImplementedError: if ``u`` is of any other type.
    """
    if isinstance(u, torch.Tensor):
        return torch.cos(u * torch.pi / 2)
    # Integer endpoints such as 0 or 1 are valid schedule positions too.
    if isinstance(u, (int, float)):
        return math.cos(u * math.pi / 2)
    raise NotImplementedError(f"Unexpected {type(u)=} {u=}")
class ActionStat(nn.Module):
    """Per-dimension normalizer for packed action vectors.

    Holds ``mean`` and ``std`` buffers of length D. The trailing axis of the input
    packs S consecutive D-dim actions (length S * D), and the statistics are
    broadcast over each of them. Any number of leading axes is supported.
    """

    def __init__(self, input_info):
        super().__init__()
        # input_info is indexed as (mean, std); each entry is passed to
        # torch.FloatTensor, presumably a flat list of D floats (confirm with caller).
        self.register_buffer("mean", torch.FloatTensor(input_info[0]))
        self.register_buffer("std", torch.FloatTensor(input_info[1]))

    def _split_actions(self, x):
        """Reshape (..., S * D) -> (..., S, D) so the D-length stats broadcast."""
        return x.reshape(*x.shape[:-1], -1, len(self.mean))

    @staticmethod
    def _merge_actions(x):
        """Reshape (..., S, D) -> (..., S * D)."""
        return x.reshape(*x.shape[:-2], -1)

    def forward(self, x):
        """Normalize ``x`` of shape (..., S * D) as ``(x - mean) / (std + 1e-10)``.

        In the model, x is (B, T, S * D): T is the window length, S the stride in
        the datasets and D the action dimensions.
        """
        x = self._split_actions(x)
        x = (x - self.mean) / (self.std + 1e-10)
        return self._merge_actions(x)

    def extra_repr(self):
        return f"mean={self.mean}, std={self.std}"

    def unnormalize(self, actions):
        """Invert ``forward``: ``actions * (std + 1e-10) + mean`` over (..., S * D)."""
        actions = self._split_actions(actions)
        actions = actions * (self.std + 1e-10) + self.mean
        return self._merge_actions(actions)
class STMaskGIT(nn.Module, PyTorchModelHubMixin):
    """Spatio-temporal MaskGIT over discrete video tokens, with optional action tokens.

    Next-Token prediction as done in https://arxiv.org/pdf/2402.15391.pdf
    Each frame has S = h * w image tokens. When actions are concatenated or jointly
    predicted, ``config.action_token_size`` extra tokens are appended to every frame.
    """

    def __init__(self, config: GenieConfig):
        super().__init__()
        self.h = self.w = math.isqrt(config.S)
        assert self.h**2 == config.S, "Expected S to be square"

        # STTransformerDecoder
        self.decoder = STTransformerDecoder(
            num_layers=config.num_layers,
            num_heads=config.num_heads,
            d_model=config.d_model,
            qkv_bias=config.qkv_bias,
            proj_bias=config.proj_bias,
            qk_norm=config.qk_norm,
            use_mup=config.use_mup,
            attn_drop=config.attn_drop,
            mlp_ratio=config.mlp_ratio,
            mlp_bias=config.mlp_bias,
            mlp_drop=config.mlp_drop,
            action_processing=config.action_network,
            random_dummy_action=config.random_dummy_action,
            jointly_predict_actions=config.jointly_predict_actions,
            mask_token_id=config.image_vocab_size,
        )

        # learnable embedding for the maximum image sizes (image tokens + action tokens)
        self.pos_embed_TSC = torch.nn.Parameter(
            torch.zeros(1, config.T, config.S + config.action_token_size, config.d_model)
        )
        print(f"{self.h=} {self.w=} {config.S=} {config.T=} {config.d_model=}")

        self.mask_token_id = config.image_vocab_size
        self.seq_len = config.S
        # Optional (B, T, 1, 1) mask selecting the timesteps whose action conditioning
        # is replaced by `action_mask_tokens`. `forward` sets it on every call that has
        # `action_ids`.
        # NOTE(review): it persists on the module afterwards and is also read by
        # `compute_logits` during generation.
        self.relevant_action_mask = None

        self.token_embed = FactorizedEmbedding(  # also works for num_factored_vocabs = 1
            factored_vocab_size=config.factored_vocab_size,
            num_factored_vocabs=config.num_factored_vocabs,
            d_model=config.d_model,
            mask_token_id=self.mask_token_id,
        )

        cls = FixedMuReadout if config.use_mup else nn.Linear  # (Fixed)MuReadout might slow down compiled training?
        self.out_x_proj = cls(config.d_model, config.factored_vocab_size * config.num_factored_vocabs)

        self.config = config
        # learnable (1, T, 1, d_model) tokens substituted for masked/absent actions
        self.action_mask_tokens = torch.nn.Parameter(torch.zeros(1, config.T, 1, config.d_model))
        if (self.config.init_actions or self.config.use_actions) and self.config.action_domains is not None:
            self.init_action_projectors(
                self.config.action_domains,
                self.config.d_actions,
                self.config.action_stats,
                self.config.action_network,
            )

    @staticmethod
    def _domain_key(domain):
        """Return the domain name.

        ``domain`` may be a single str or a per-batch sequence of names, in which
        case the first entry is used. Indexing a plain str with ``[0]`` would
        wrongly yield its first character (e.g. "d" for the default "default").
        """
        if domain is None or isinstance(domain, str):
            return domain
        return domain[0]

    def init_action_projectors(
        self,
        domains: list[str],
        d_actions: list[int],
        action_stats: list[list[list[float]]],
        action_network: str = "mlp",
        use_diffusion: bool = False,
    ):
        """Create the per-domain action stems and per-layer action projectors.

        This is called from ``__init__`` when the config enables actions; it is also
        called externally for training. It records ``domains``, ``d_actions`` and
        ``action_stats`` on ``self.config`` and sets ``self.config.init_actions = True``.
        """
        self.config.init_actions = True
        self.config.action_domains = domains
        self.config.d_actions = d_actions
        self.config.action_stats = action_stats
        self.action_preprocessor = nn.ModuleDict()
        self.action_mlp = nn.ModuleDict()
        self.action_out_projectors = nn.ModuleDict()

        print("use diffusion: ", use_diffusion)
        print("init action network:", action_network)
        cls = FixedMuReadout if self.config.use_mup else nn.Linear  # (Fixed)MuReadout might slow down compiled training?
        # We currently skip datasets if they fail but `domains` is all specified datasets, so we get misalignment in this case
        assert len(domains) == len(d_actions) == len(action_stats), f"{len(domains)=} {len(d_actions)=} {len(action_stats)=}"
        for domain, d_action, action_stat in zip(domains, d_actions, action_stats):
            # by default, we share these modules across layers
            self.action_preprocessor[domain] = ActionStat(action_stat)
            self.action_mlp[domain] = BasicMLP(d_action, self.config.d_model)
            if not use_diffusion:
                self.action_out_projectors[domain] = cls(self.config.d_model, d_action)

        # by default, the conditioning is separate for each decoder layer
        for layer in self.decoder.layers:
            layer.action_projectors = nn.ModuleDict()
            for domain, d_action in zip(domains, d_actions):
                if "mlp" in action_network:
                    layer.action_projectors[domain] = nn.Identity()
                elif "cross_attention" in action_network:
                    layer.action_projectors[domain] = BasicCrossAttention(
                        num_heads=8,
                        d_model=self.config.d_model,
                        k_model=d_action,
                    )
                elif "modulate" in action_network:
                    layer.action_projectors[domain] = ModulateLayer(self.config.d_model, self.config.d_model)

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.LongTensor,
        max_new_tokens: int,
        min_new_tokens: int = None,
        return_logits: bool = False,
        return_with_actions: bool = False,
        maskgit_steps: int = 1,
        temperature: float = 0.0,
        action_ids: torch.Tensor = None,
        domain: str = "default",
        **kwargs
    ) -> tuple[torch.LongTensor, torch.FloatTensor]:
        """
        Args designed to match the format of Llama.
        We ignore `attention_mask`, and use `max_new_tokens` to determine the number of frames to generate.

        Args:
            input_ids: (B, T_prompt * H * W) unfactorized prompt token ids.
            max_new_tokens: `max_new_tokens // (H * W)` new frames are generated.
            min_new_tokens: must be None or equal to `max_new_tokens`.
            return_logits / return_with_actions: choose the extra return value
                (`return_with_actions` wins if both are set).
            maskgit_steps, temperature: forwarded to `maskgit_generate`.
            action_ids, domain: action conditioning forwarded to `compute_logits`;
                `domain` may be a str or a sequence whose first entry is used.
            kwargs: optional `h` / `w` (sequences, first entry used) for non-default
                resolutions, plus anything `compute_logits` reads (e.g. `skip_normalization`).

        Returns:
            `predicted_tokens`, or `(predicted_tokens, actions)` if `return_with_actions`,
            or `(predicted_tokens, factored_logits)` if `return_logits`.
            predicted_tokens: (B, (T_prompt + num_new_frames) * H * W), the prompt followed
                by the generated frames.
            actions: unnormalized action predictions from the last model pass, or None if
                no frames were generated or the model does not predict actions.
            factored_logits: (B, factored_vocab_size, num_factored_vocabs, num_new_frames, H, W).
        """
        assert min_new_tokens in (None, max_new_tokens), \
            "Expecting `min_new_tokens`, if specified, to match `max_new_tokens`."

        h = kwargs["h"][0] if "h" in kwargs else self.h
        w = kwargs["w"][0] if "w" in kwargs else self.w
        S = h * w
        num_new_frames = max_new_tokens // S
        inputs_THW = rearrange(input_ids.clone(), "b (t h w) -> b t h w", h=h, w=w)
        # append fully masked frames that will be filled in one at a time
        inputs_masked_THW = torch.cat([
            inputs_THW,
            torch.full((input_ids.size(0), num_new_frames, h, w),
                       self.mask_token_id, dtype=torch.long, device=input_ids.device),
        ], dim=1)

        all_factored_logits = []
        actions = None
        for timestep in range(inputs_THW.size(1), inputs_THW.size(1) + num_new_frames):
            # could change sampling hparams
            sample_HW, factored_logits, actions = self.maskgit_generate(
                inputs_masked_THW,
                timestep,
                maskgit_steps=maskgit_steps,
                temperature=temperature,
                action_ids=action_ids,
                domain=domain,
                **kwargs
            )
            inputs_masked_THW[:, timestep] = sample_HW
            all_factored_logits.append(factored_logits)

        predicted_tokens = rearrange(inputs_masked_THW, "B T H W -> B (T H W)")
        if return_with_actions:
            if actions is not None:
                actions = self.action_preprocessor[self._domain_key(domain)].unnormalize(actions)
            return predicted_tokens, actions
        if return_logits:
            return predicted_tokens, torch.stack(all_factored_logits, dim=3)
        return predicted_tokens

    def init_mask(self, prompt_THW, t=1):
        """Return an all-False "already unmasked" mask of shape (B, t * H * W).

        Only the `t` frames being generated are covered, not the whole prompt. H and W
        come from `prompt_THW`, so non-default resolutions get a correctly sized mask.
        """
        B, _, H, W = prompt_THW.shape
        return torch.zeros(B, t * H * W, dtype=torch.bool, device=prompt_THW.device)

    def _maskgit_sample(
        self,
        prompt_THW,
        out_t_min,
        out_t_max,
        maskgit_steps,
        temperature,
        unmask_mode,
        action_ids,
        domain,
        show_progress=False,
        **kwargs
    ):
        """Shared MaskGIT loop that jointly fills frames [out_t_min, out_t_max) of `prompt_THW` in place.

        Returns:
            (samples_THW, orig_logits_CTHW, action_outputs):
            samples_THW is (B, out_t_max - out_t_min, H, W). orig_logits_CTHW holds
            the logits of the first model pass for those frames, (B, C, num_frames, H, W).
            action_outputs comes from the last `compute_logits` call.
        """
        if maskgit_steps < 1:
            raise ValueError(f"Expected `maskgit_steps` >= 1, got {maskgit_steps}")

        bs, _, h, w = prompt_THW.shape
        num_frames = out_t_max - out_t_min
        num_tokens = num_frames * h * w
        vocab_size = self.config.factored_vocab_size
        num_vocabs = self.config.num_factored_vocabs

        # which target tokens are already fixed; modified in place on each iteration
        unmasked = self.init_mask(prompt_THW, t=num_frames)

        logits_CTHW, action_outputs = self.compute_logits(prompt_THW, action_ids=action_ids, domain=domain, **kwargs)
        logits_CTHW = logits_CTHW[:, :, out_t_min:out_t_max]
        # Return these original logits, not logits after partially sampling.
        orig_logits_CTHW = logits_CTHW.clone()

        steps = tqdm(range(maskgit_steps)) if show_progress else range(maskgit_steps)
        for step in steps:
            # Perform a single maskgit step (cosine schedule), updating unmasked in place
            if step > 0:
                # recompute logits with the updated prompt (actions are recomputed too)
                logits_CTHW, action_outputs = self.compute_logits(prompt_THW, action_ids=action_ids, domain=domain, **kwargs)
                logits_CTHW = logits_CTHW[:, :, out_t_min:out_t_max]

            factored_logits = rearrange(logits_CTHW, "b (num_vocabs vocab_size) t h w -> b vocab_size num_vocabs t h w",
                                        vocab_size=vocab_size, num_vocabs=num_vocabs)
            factored_probs = torch.nn.functional.softmax(factored_logits, dim=1)

            samples_THW = torch.zeros((bs, num_frames, h, w), dtype=torch.long, device=prompt_THW.device)
            confidences_THW = torch.ones((bs, num_frames, h, w), dtype=torch.float, device=prompt_THW.device)
            # Combine the factored samples into an unfactorized id, with the last factor
            # as the most significant digit (hence the flip).
            for logits_v, probs_v in zip(factored_logits.flip(2).unbind(2), factored_probs.flip(2).unbind(2)):
                if temperature <= 1e-8:  # greedy sampling
                    sample = probs_v.argmax(dim=1)
                else:
                    # Categorical expects last dim to be channel dim. Temperature must scale
                    # the logits: dividing `probs` by it would be undone by normalization.
                    dist = torch.distributions.categorical.Categorical(
                        logits=rearrange(logits_v, "b vocab_size ... -> b ... vocab_size") / temperature
                    )
                    sample = dist.sample()
                samples_THW *= vocab_size
                samples_THW += sample
                confidences_THW *= torch.gather(probs_v, 1, sample.unsqueeze(1)).squeeze(1)

            prev_unmasked = unmasked.clone()
            prev_img_flat = rearrange(prompt_THW[:, out_t_min:out_t_max], "B T H W -> B (T H W)")
            samples_flat = samples_THW.reshape(bs, num_tokens)

            if step != maskgit_steps - 1:  # skip masking for last maskgit step
                # use cosine mask scheduling function, n is how many target tokens to keep masked
                n = math.ceil(cosine_schedule((step + 1) / maskgit_steps) * num_tokens)
                if unmask_mode == "greedy":
                    # set the n patches with the least confidence to mask_token
                    confidences_flat = confidences_THW.reshape(bs, num_tokens)
                elif unmask_mode == "random":
                    # randomize confidences, so that patches are randomly masked
                    # (not a probability distribution anymore, only relative order matters)
                    confidences_flat = torch.rand_like(confidences_THW).reshape(bs, num_tokens)
                else:
                    raise NotImplementedError(f"Expected `unmask_mode` to be one of ['greedy', 'random'], "
                                              f"got {unmask_mode}")

                confidences_flat[unmasked] = torch.inf
                least_confident_tokens = torch.argsort(confidences_flat, dim=1)
                # unmask the (num_tokens - n) most confident tokens
                unmasked.scatter_(1, least_confident_tokens[:, n:], True)
                samples_flat.scatter_(1, least_confident_tokens[:, :n], self.mask_token_id)

            # copy previously unmasked values from prompt input into sample
            samples_flat[prev_unmasked] = prev_img_flat[prev_unmasked]
            samples_THW = samples_flat.reshape(bs, num_frames, h, w)
            # feed back to iteratively decode
            prompt_THW[:, out_t_min:out_t_max] = samples_THW

        return samples_THW, orig_logits_CTHW, action_outputs

    @torch.no_grad()
    def maskgit_generate(
        self,
        prompt_THW: torch.LongTensor,
        out_t: int,
        maskgit_steps: int = 1,
        temperature: float = 0.0,
        unmask_mode: str = "random",
        action_ids=None,
        domain="default",
        **kwargs
    ) -> tuple[torch.LongTensor, torch.FloatTensor]:
        """
        Performs MaskGIT-style inference to predict frame `out_t`.

        Args:
            prompt_THW: Unfactorized token ids, size (B, T, H, W). Frame `out_t` is written in place.
            out_t: Will return predicted unfactorized token ids for this frame.
                Should be >= 1 as the 0th frame is assumed to be given.
                Expects all future frames to be fully masked.
            maskgit_steps: The number of MaskGIT-style inference steps to take (>= 1).
            temperature: Sampling temperature applied to the logits.
                In the factorized case, sampling is performed for each factorized vocabulary independently.
                If temperature is <= 1e-8, will be greedy (i.e. argmax) instead of actual sampling.
            unmask_mode: The method to determine tokens to unmask during each step of MaskGIT inference.
                Options:
                    - "greedy" for unmasking the most confident tokens, which matches the original MaskGIT
                    - "random" for randomly choosing tokens to unmask
                "greedy" tends to copy the previous frame, so we default to "random" instead.
            action_ids, domain, kwargs: forwarded to `compute_logits`.

        Returns: (sample_HW, factored_logits, action_outputs)
            sample_HW: size (B, H, W) corresponding to predicted unfactorized token ids for frame `out_t`.
            factored_logits: size (B, factored_vocab_size, num_factored_vocabs, H, W), from the first pass.
            action_outputs: action predictions of the last `compute_logits` pass (may be None).
        """
        # assume we have pre-masked z{out_t}...zT with all masks
        assert out_t, "maskgit_generate requires out_t > 0"
        assert torch.all(prompt_THW[:, out_t:] == self.mask_token_id), \
            f"when generating z{out_t}, frames {out_t} and later must be masked"

        samples_THW, logits_CTHW, action_outputs = self._maskgit_sample(
            prompt_THW, out_t, out_t + 1,
            maskgit_steps=maskgit_steps,
            temperature=temperature,
            unmask_mode=unmask_mode,
            action_ids=action_ids,
            domain=domain,
            **kwargs
        )
        factored_logits = rearrange(
            logits_CTHW[:, :, 0], "B (num_vocabs vocab_size) H W -> B vocab_size num_vocabs H W",
            vocab_size=self.config.factored_vocab_size, num_vocabs=self.config.num_factored_vocabs,
        )
        return samples_THW[:, 0], factored_logits, action_outputs

    @torch.no_grad()
    def maskgit_generate_horizon(
        self,
        prompt_THW: torch.LongTensor,
        out_t_min: int,
        out_t_max: int,
        maskgit_steps: int = 1,
        temperature: float = 0.0,
        unmask_mode: str = "random",
        action_ids=None,
        domain="default",
        skip_normalization: bool = False,
        **kwargs
    ) -> tuple[torch.LongTensor, torch.FloatTensor]:
        """
        Performs MaskGIT-style inference that jointly predicts frames [`out_t_min`, `out_t_max`).

        Args:
            prompt_THW: Unfactorized token ids, size (B, T, H, W). The target frames are written in place.
            out_t_min, out_t_max: half-open range of frames to generate.
                `out_t_min` should be >= 1 as the 0th frame is assumed to be given.
                Expects frames `out_t_min` and later to be fully masked.
            maskgit_steps, temperature, unmask_mode: as in `maskgit_generate`. A tqdm
                progress bar is shown over the steps.
            action_ids, domain, kwargs: forwarded to `compute_logits`.
            skip_normalization: forwarded to `compute_logits` (skips the action normalizer).

        Returns: (samples_THW, factored_logits, action_outputs)
            samples_THW: size (B, out_t_max - out_t_min, H, W) predicted unfactorized token ids.
            factored_logits: size (B, factored_vocab_size, num_factored_vocabs, out_t_max - out_t_min, H, W),
                from the first pass.
            action_outputs: action predictions of the last `compute_logits` pass (may be None).
        """
        # assume we have pre-masked z{out_t_min}...zT with all masks
        assert out_t_min, "maskgit_generate_horizon requires out_t_min > 0"
        assert out_t_max > out_t_min, f"Expected out_t_max > out_t_min, got {out_t_min=} {out_t_max=}"
        assert torch.all(prompt_THW[:, out_t_min:] == self.mask_token_id), \
            f"when generating z{out_t_min}, frames {out_t_min} and later must be masked"

        samples_THW, logits_CTHW, action_outputs = self._maskgit_sample(
            prompt_THW, out_t_min, out_t_max,
            maskgit_steps=maskgit_steps,
            temperature=temperature,
            unmask_mode=unmask_mode,
            action_ids=action_ids,
            domain=domain,
            show_progress=True,
            skip_normalization=skip_normalization,
            **kwargs
        )
        factored_logits = rearrange(
            logits_CTHW, "B (num_vocabs vocab_size) T H W -> B vocab_size num_vocabs T H W",
            vocab_size=self.config.factored_vocab_size, num_vocabs=self.config.num_factored_vocabs,
        )
        return samples_THW, factored_logits, action_outputs

    def compute_video_loss_and_acc(self, logits_CTHW, targets_THW, relevant_mask_THW):
        """Masked cross-entropy loss and token accuracy for video token prediction.

        Args:
            logits_CTHW: (B, C, T, H, W) logits, C = num_factored_vocabs * factored_vocab_size.
            targets_THW: (B, T * H * W) unfactorized target token ids.
            relevant_mask_THW: (B, T - 1, H, W) mask of the positions that count. Frame 0
                is always dropped because the first frame is always unmasked.

        Returns:
            (relevant_loss, relevant_acc): means over the selected positions (0 if none).
        """
        # Take T/H/W from the logits so that resolutions other than (self.h, self.w) work.
        T, H, W = logits_CTHW.shape[2:]
        targets_THW = rearrange(targets_THW.clone(), "B (T H W) -> B T H W", T=T, H=H, W=W)
        logits_CTHW, targets_THW = logits_CTHW[:, :, 1:], targets_THW[:, 1:]  # first frame always unmasked
        factored_logits = rearrange(logits_CTHW,
                                    "b (num_vocabs vocab_size) t h w -> b vocab_size num_vocabs t h w",
                                    vocab_size=self.config.factored_vocab_size,
                                    num_vocabs=self.config.num_factored_vocabs)

        factored_targets = factorize_labels(targets_THW)
        # adding label_smoothing; per-factor losses are summed
        loss_THW = F.cross_entropy(factored_logits, factored_targets, reduction="none", label_smoothing=0.01).sum(dim=1)
        # a token counts as correct only if every factor is correct
        acc_THW = (factored_logits.argmax(dim=1) == factored_targets).all(dim=1)

        # Compute the mean masked error. Multiplying by the mask instead of indexing is cheaper;
        # clamp avoids 0/0 = NaN when no position is selected.
        num_masked_tokens = torch.sum(relevant_mask_THW).clamp(min=1)
        relevant_loss = torch.sum(loss_THW * relevant_mask_THW) / num_masked_tokens
        relevant_acc = torch.sum(acc_THW * relevant_mask_THW).float() / num_masked_tokens
        # only optimize on the masked/noised logits.
        return relevant_loss, relevant_acc

    def compute_logits(self, x_THW: torch.Tensor, action_ids: torch.Tensor = None, domain=None, **kwargs):
        """Run the decoder and return (decoded_states, decoded_actions).

        Args:
            x_THW: (B, T, H, W) unfactorized token ids for z0,...,zT.
            action_ids: optional per-timestep actions. They are normalized (unless
                `kwargs["skip_normalization"]`) and embedded with the domain's MLP.
            domain: a str, or a sequence whose first entry is the domain name.
            kwargs: optional `h` / `w` (sequences, first entry used; both required together).

        Returns:
            decoded_states: (B, C, T, H, W) logits if `config.jointly_predict_states`, else None.
            decoded_actions: (B, T, d_action) if `config.jointly_predict_actions`, else None.
        """
        h, w = self.h, self.w
        if "h" in kwargs:
            assert "w" in kwargs
            h = kwargs["h"][0]
            w = kwargs["w"][0]

        x_TS = rearrange(x_THW, "B T H W -> B T (H W)")
        x_TSC = self.token_embed(x_TS)
        T = x_TSC.shape[1]
        domain = self._domain_key(domain)

        if action_ids is not None:
            # currently, action_preprocessor just normalizes the actions
            if not kwargs.get("skip_normalization", False):
                action_ids = self.action_preprocessor[domain](action_ids)
            action_ids = self.action_mlp[domain](action_ids)  # [B, T, D]

            if "concat" in self.config.action_network:
                # one action embedding per frame, repeated over the action tokens
                action_condition = action_ids[:, :T, None].repeat(1, 1, self.config.action_token_size, 1)  # [B, T, S, D]
                if self.relevant_action_mask is not None and self.config.jointly_predict_actions:
                    # randomly dropped conditioning: masked timesteps use the learnable mask tokens
                    action_condition = self.relevant_action_mask[:, :T] * self.action_mask_tokens[:, :T] + \
                        (1 - self.relevant_action_mask[:, :T]) * action_condition[:, :T]
                x_TSC = torch.concat((x_TSC, action_condition), dim=2)  # [B, T, S, D]
        elif self.config.jointly_predict_actions:
            # all masked: there are no input actions, so predict actions as in a policy
            action_condition = self.action_mask_tokens[:, :T].repeat(1, 1, self.config.action_token_size, 1)
            x_TSC = torch.concat((x_TSC, action_condition), dim=2)  # [B, T, S, D]

        # additive position embeddings, using the same vocab space
        x_TSC = self.decoder(x_TSC + self.pos_embed_TSC[:, :x_TSC.shape[1], :x_TSC.shape[2]], action_ids=action_ids, domain=domain)

        decoded_actions = None
        decoded_states = None
        if self.config.jointly_predict_actions:
            decoded_actions = x_TSC[:, :, -self.config.action_token_size:].mean(dim=2)  # pool all action tokens
            decoded_actions = self.action_out_projectors[domain](decoded_actions)

        if self.config.jointly_predict_states:
            x_TSC = x_TSC[:, :, :h * w]  # remove action tokens
            x_next_TSC = self.out_x_proj(x_TSC)
            decoded_states = rearrange(x_next_TSC, "B T (H W) C -> B C T H W", H=h, W=w)

        return decoded_states, decoded_actions

    def forward(self, input_ids, labels, action_ids=None, domain="default", **kwargs):
        """
        input_ids: size (B, T * H * W) represents video sequences
        labels: size (B, T * H * W) represents video sequences
        action_ids: size (B, T, Da) represents action sequences
        kwargs: optional `h` / `w` (sequences, first entry used) to support varying resolutions;
            everything is also forwarded to `compute_logits`.

        Returns a ModelOutput with `loss`, `acc` and `logits`. When the model predicts
        actions it also has `actions`, plus `action_loss` if `action_ids` were given.
        """
        T = self.config.T
        H = kwargs["h"][0] if "h" in kwargs else self.h
        W = kwargs["w"][0] if "w" in kwargs else self.w
        x_THW = rearrange(input_ids, "B (T H W) -> B T H W", T=T, H=H, W=W)

        action_labels = None
        if action_ids is not None:
            # NOTE(review): labels are the raw `action_ids`. `compute_logits` normalizes its action
            # inputs (unless skip_normalization) and `generate` unnormalizes predictions, so
            # confirm which space the action targets are meant to be in.
            action_labels = action_ids.clone()
            # In training, mask a random fraction of timesteps. The fraction ranges from 0 (fully
            # unmasked, as in video prediction) to 1 (fully masked, as in policies).
            drop_ratio = torch.rand(len(action_ids), 1, 1)
            action_mask = torch.rand(len(action_ids), T, 1) < drop_ratio
            # follow the input's device instead of assuming CUDA
            self.relevant_action_mask = action_mask.unsqueeze(-1).to(device=x_THW.device, dtype=x_THW.dtype)

        # Record the loss over masked tokens only to make it more comparable to LLM baselines
        logits_CTHW, action_outputs = self.compute_logits(x_THW, action_ids=action_ids, domain=domain, **kwargs)
        relevant_mask = x_THW[:, 1:] == self.mask_token_id  # could also get mask of corrupted tokens by uncommenting line in `get_maskgit_collator`
        relevant_loss = torch.zeros(1).to(x_THW.device)
        relevant_acc = torch.zeros(1).to(x_THW.device)
        if logits_CTHW is not None:
            relevant_loss, relevant_acc = self.compute_video_loss_and_acc(logits_CTHW, labels, relevant_mask)

        if action_outputs is None:
            return ModelOutput(loss=relevant_loss, acc=relevant_acc, logits=logits_CTHW)
        if action_labels is None:
            # actions were predicted from mask tokens only; there is nothing to supervise against
            return ModelOutput(loss=relevant_loss, acc=relevant_acc, logits=logits_CTHW, actions=action_outputs)

        # `reduction="none"` keeps per-element errors so that only masked timesteps contribute
        # (the deprecated `reduce=` argument would collapse to a scalar mean).
        action_loss = F.mse_loss(action_outputs, action_labels[:, :T], reduction="none")
        action_loss = (action_loss * self.relevant_action_mask[..., 0]).mean()
        return ModelOutput(loss=relevant_loss,
                           acc=relevant_acc,
                           logits=logits_CTHW,
                           action_loss=action_loss,
                           actions=action_outputs)

    def init_weights(self):
        """Normal(0, 0.02) init for Linear/Embedding weights and zero biases; works with and without muP."""
        std = 0.02
        for module in self.modules():
            if isinstance(module, nn.Linear):
                if hasattr(module.weight, "infshape"):  # muP
                    mup.normal_(module.weight, mean=0.0, std=std)
                else:
                    module.weight.data.normal_(mean=0.0, std=std)

                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.Embedding):
                module.weight.data.normal_(mean=0.0, std=std)
                if module.padding_idx is not None:
                    module.weight.data[module.padding_idx].zero_()

    def set_mup_shapes(self, rescale_params=False):
        """Register muP base shapes against a reference model with 8 heads and d_model=256."""
        base_config = self.config.shallow_copy()
        base_config.num_heads = 8
        base_config.d_model = 256  # currently hardcoding to this shape
        base_model = STMaskGIT(base_config)
        mup.set_base_shapes(self, base_model, rescale_params=rescale_params)

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Extra logic for muP: re-register base shapes after loading."""
        # Must be a classmethod: PyTorchModelHubMixin.from_pretrained is called on the class,
        # and zero-argument super() needs `cls` to be the class itself.
        model = super().from_pretrained(*args, **kwargs)
        if model.config.use_mup:
            model.set_mup_shapes(rescale_params=False)

        return model
class FixedMuReadout(mup.MuReadout):
    """MuReadout with a small-gain Xavier init and a torch.compile-friendly forward."""

    def __init__(self, d_input, d_output):
        super().__init__(d_input, d_output)
        self.apply(self._init_weights)

    def _init_weights(self, m):  # TODO: muP?
        """Xavier-uniform (gain 0.01) weights and zero bias for every nn.Linear submodule."""
        if not isinstance(m, nn.Linear):
            return
        torch.nn.init.xavier_uniform_(m.weight, gain=0.01)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """
        Scale the input by `output_mult / width_mult()` and apply the plain linear map.

        Using `return super(mup.MuReadout, self).forward(self.output_mult * x / self.width_mult())`
        with `torch.compile` results in two divisions by `self.width_mult()` for some reason,
        so nn.Linear.forward is called explicitly instead.
        """
        scaled_x = self.output_mult * x / self.width_mult()
        return nn.Linear.forward(self, scaled_x)