# https://github.com/LTH14/mar/tree/main/
from functools import partial
import numpy as np
from tqdm import tqdm
import scipy.stats as stats
import math
import torch
import torch.nn as nn
from einops import rearrange
import mup

from genie.config import DiffusionGenieConfig
from .diffloss import DiffLoss
from .st_mask_git import STMaskGIT
from transformers.utils import ModelOutput


def mask_by_order(mask_len, order, bsz, seq_len):
    # Mark the first `mask_len` positions of each per-sample generation order as masked.
    masking = torch.zeros(bsz, seq_len).cuda()
    masking = torch.scatter(masking, dim=-1, index=order[:, :mask_len.long()],
                            src=torch.ones(bsz, seq_len).cuda()).bool()
    return masking
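
# Illustrative example (added for clarity, not part of the original repo): with one sample,
# seq_len=4, a sampled order of [[2, 0, 3, 1]] and mask_len=2, the first two entries of the
# order (positions 2 and 0) are marked as masked:
#   mask_by_order(torch.tensor(2.0), torch.tensor([[2, 0, 3, 1]]).cuda(), 1, 4)
#   -> tensor([[ True, False,  True, False]], device='cuda:0')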


class FixedMuReadout(mup.MuReadout):
    def forward(self, x):
        """
        Using `return super(mup.MuReadout, self).forward(self.output_mult * x / self.width_mult())` with `torch.compile`
        results in two divisions by `self.width_mult()` for some reason.
        """
        # return F.linear(self.output_mult * x / self.width_mult(), self.weight, self.bias)  # equivalent
        return nn.Linear.forward(self, self.output_mult * x / self.width_mult())
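
# Note (added for clarity, not from the original repo): `mup.MuReadout` rescales the readout
# input by `output_mult / width_mult()` so that logits stay width-invariant under muP.
# Calling `nn.Linear.forward` directly, as above, applies that scaling exactly once, whereas
# the docstring observes that `super().forward(...)` under `torch.compile` applied it twice.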


class STMAR(STMaskGIT):
    """ Spatial-Time MAR with a VisionTransformer backbone
    """
    def __init__(self, config: DiffusionGenieConfig):
        self.diffloss_w = config.diffloss_w
        self.diffloss_d = config.diffloss_d
        self.num_sampling_steps = config.num_sampling_steps
        self.grad_checkpointing = config.grad_checkpointing

        # --------------------------------------------------------------------------
        # VAE and patchify specifics
        self.patch_size = config.patch_size
        self.vae_stride = config.vae_stride
        self.buffer_size = config.buffer_size
        self.vae_embed_dim = config.vae_embed_dim
        self.maskgit_steps = config.maskgit_steps
        super().__init__(config)

        # --------------------------------------------------------------------------
        self.mask_token = nn.Parameter(torch.zeros(1, 1, self.vae_embed_dim))
        self.token_embed = nn.Linear(config.vae_embed_dim * self.config.patch_size ** 2, config.d_model, bias=False)  # hard-coded
        cls = FixedMuReadout if config.use_mup else nn.Linear  # (Fixed)MuReadout might slow down compiled training?
        self.out_x_proj = cls(config.d_model, config.d_model)
        self.decoder_norm = nn.LayerNorm(config.d_model, eps=1e-6)
        self.z_proj_ln = nn.LayerNorm(config.d_model, eps=1e-6)
        self.seq_len = config.S // (self.config.patch_size ** 2)
        self.diffusion_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len * config.T, config.d_model))

        # --------------------------------------------------------------------------
        # Diffusion Loss
        self.diffloss = DiffLoss(
            target_channels=config.vae_embed_dim * self.config.patch_size ** 2,
            z_channels=config.d_model,
            width=config.diffloss_w,
            depth=config.diffloss_d,
            num_sampling_steps=config.num_sampling_steps,
            grad_checkpointing=config.grad_checkpointing
        )
        # print(self.config.init_actions, self.config.use_actions, self.config.action_domains is not None)
        self.diffusion_batch_mul = config.diffusion_batch_mul
        self.initialize_weights()
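
    # Note (added; not in the original file): the constructor reads the following
    # DiffusionGenieConfig fields: diffloss_w, diffloss_d, num_sampling_steps,
    # grad_checkpointing, patch_size, vae_stride, buffer_size, vae_embed_dim,
    # maskgit_steps, use_mup, d_model, S and T (presumably spatial tokens per frame
    # and number of frames), and diffusion_batch_mul, in addition to whatever
    # STMaskGIT.__init__ consumes.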

    def init_action_projectors(
        self,
        domains: list[str],
        d_actions: list[int],
        action_stats: list[list[list[float]]],
        action_network: str = "mlp",
    ):
        super().init_action_projectors(domains, d_actions, action_stats, action_network, use_diffusion=True)
        self.action_diff_losses = nn.ModuleDict()
        # action heads are heterogeneous
        for domain, d_action in zip(self.config.action_domains, self.config.d_actions):
            self.action_diff_losses[domain] = DiffLoss(
                target_channels=d_action,
                z_channels=self.config.d_model,
                width=self.diffloss_w,
                depth=self.diffloss_d,
                num_sampling_steps=self.num_sampling_steps,
                grad_checkpointing=self.grad_checkpointing
            )

    def initialize_weights(self):
        # initialize nn.Linear and nn.LayerNorm parameters
        torch.nn.init.normal_(self.diffusion_pos_embed_learned, std=.02)
        self.init_weights()

    def set_mup_shapes(self, rescale_params=False):
        base_config = self.config.shallow_copy()
        base_config.num_heads = 8
        base_config.d_model = 256  # currently hardcoding to this shape
        base_model = STMAR(base_config)
        if hasattr(self, "action_preprocessor"):
            for base_layer, layer in zip(base_model.decoder.layers, self.decoder.layers):
                base_layer.action_projectors = layer.action_projectors
            base_model.action_preprocessor = self.action_preprocessor
        mup.set_base_shapes(self, base_model, rescale_params=rescale_params)

    def compute_action_loss_and_acc(self, z, target, domain, mask=None):
        bsz, seq_len, *_ = target.shape
        # not sure whether this repeat is needed
        target = target.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1)
        z = z.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1)
        if mask is not None:
            mask = mask.reshape(bsz * seq_len).repeat(self.diffusion_batch_mul)
        loss = self.action_diff_losses[domain[0]](z=z, target=target, mask=mask)
        acc = torch.zeros_like(loss)
        return loss, acc

    def compute_video_loss_and_acc(self, z, target, mask=None):
        z = rearrange(z, "B C T H W -> B (T H W) C").float()
        target = rearrange(target, "B T H W C -> B (T H W) C").float()
        bsz, seq_len, *_ = target.shape
        target = target.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1)
        z = z.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1)
        if mask is not None:
            mask = mask.reshape(bsz * seq_len).repeat(self.diffusion_batch_mul)
        loss = self.diffloss(z=z, target=target, mask=mask)
        acc = torch.zeros_like(loss)
        return loss, acc
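
    # Shape note (added; not in the original file): `z` arrives as (B, C, T, H, W)
    # conditioning latents and `target` as (B, T, H, W, C) patchified VAE latents.
    # Both are flattened to (B*T*H*W, -1) and repeated `diffusion_batch_mul` times,
    # presumably so each token position is paired with several diffusion-noise draws
    # per batch, as in the original MAR training recipe.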

    def compute_latents(self, x_THW, action_ids: torch.Tensor = None, domain=None, action_mask=None, **kwargs):
        # x_THW is for z0,...,zT while x_targets is z1,...,zT
        pos_embed_TSC = self.pos_embed_TSC
        diffusion_pos_embed_learned = self.diffusion_pos_embed_learned
        b, t, h, w, c = x_THW.shape
        x_TSC = rearrange(x_THW, "B T H W C -> B T (H W) C").float()
        x_TSC = self.token_embed(x_TSC)
        T = x_TSC.shape[1]
        if action_ids is not None:
            # currently, action_preprocessor just normalizes the actions
            skip_normalization = kwargs.get("skip_normalization", False)
            if not skip_normalization:
                action_ids = self.action_preprocessor[domain[0]](action_ids)
            action_ids = self.action_mlp[domain[0]](action_ids)  # [B, T, D]
            if "concat" in self.config.action_network:
                # randomly drop the conditioning
                if self.config.action_network == "resampler_concat":
                    action_condition = self.action_projectors[domain[0]](action_ids[:, :T])
                else:
                    action_condition = action_ids[:, :T, None].repeat(1, 1, self.config.action_token_size, 1)  # [B, T, S, C]
                # for the training losses, the fraction of masked action tokens ranges from
                # 0 (fully unmasked, as in video prediction) to 1 (fully masked, as in policies)
                # if we have actions and are trying to predict actions
                # if self.config.jointly_predict_actions and action_mask is not None:
                #     action_condition = action_mask * self.action_mask_tokens[:, :T] + (1 - action_mask) * action_condition
                x_TSC = torch.concat((x_TSC, action_condition), dim=2)  # [B, T, S, C]
        elif self.config.jointly_predict_actions:
            # all masked when predicting actions and there are no input actions
            action_condition = self.action_mask_tokens[:, :T].repeat(1, 1, self.config.action_token_size, 1)
            x_TSC = torch.concat((x_TSC, action_condition), dim=2)  # [B, T, S, C]
        x_TSC = self.z_proj_ln(x_TSC + pos_embed_TSC[:, :x_TSC.shape[1], :x_TSC.shape[2]])
        # additive position embeddings, using the same vocab space
        domain = domain[0] if domain is not None else None
        x_TSC = self.decoder(x_TSC, action_ids=action_ids, domain=domain)
        # dummy outputs in case states/actions are not predicted
        decoded_states = rearrange(diffusion_pos_embed_learned, "B (T H W) C -> B C T H W", T=self.config.T, H=h, W=w)
        decoded_actions = None
        if self.config.jointly_predict_actions:
            decoded_actions = x_TSC[:, :, -self.config.action_token_size:].mean(dim=2)  # pool all action tokens
        # if self.config.jointly_predict_states:
        x_TSC = x_TSC[:, :, :h * w]  # remove action tokens
        x_next_TSC = self.decoder_norm(self.out_x_proj(x_TSC))
        x_next_TSC = x_next_TSC + diffusion_pos_embed_learned.view(1, self.config.T, h * w, self.config.d_model)[:, :T]
        decoded_states = rearrange(x_next_TSC, "B T (H W) C -> B C T H W", H=h, W=w)
        return decoded_states, decoded_actions

    def patchify(self, x):
        # (B, T, H, W, C) -> (B, T, H/p, W/p, C*p*p)
        bsz, t, h, w, c = x.shape
        p = self.patch_size
        h_, w_ = h // p, w // p
        x = x.reshape(bsz, t, h_, p, w_, p, c)
        x = torch.einsum('nthpwqc->nthwpqc', x)
        x = x.reshape(bsz, t, h_, w_, c * p ** 2)
        return x
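
    # Illustrative example (added; not in the original file): with patch_size=2, a latent
    # video of shape (B, T, 16, 16, C) is patchified to (B, T, 8, 8, 4*C); each output
    # channel vector concatenates the VAE channels of one 2x2 spatial patch.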

    def unpatchify(self, x):
        # input: (B, T, H, W, C*p*p) -> (B, T, H*p, W*p, C)
        p = self.patch_size
        bsz, t, h, w, _ = x.shape
        c = self.vae_embed_dim
        x = x.reshape(bsz, t, h, w, p, p, c)
        x = torch.einsum('nthwpqc->nthpwqc', x)
        x = x.reshape(bsz, t, h * p, w * p, c)
        return x
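
    # Round-trip note (added; not in the original file): `unpatchify` inverts `patchify`,
    # so `self.unpatchify(self.patchify(x))` reproduces `x` whenever H and W are divisible
    # by `patch_size`.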

    def forward(self, input_ids, labels, action_ids=None, domain="default", **kwargs):
        assert "masked_tokens_indicator" in kwargs
        relevant_mask = kwargs["masked_tokens_indicator"]
        # class embed
        T, H, W = self.config.T, self.h, self.w
        if "h" in kwargs:
            H = kwargs["h"][0]
        if "w" in kwargs:
            W = kwargs["w"][0]
        x_THW = rearrange(input_ids, "B (T H W) C -> B T H W C", T=T, H=H, W=W)
        action_mask = None
        if action_ids is not None and self.config.jointly_predict_actions:
            action_labels = action_ids.clone()
            action_mask = torch.zeros(len(action_ids), T, 1)
            random_timesteps = torch.randint(0, T, (len(action_ids), 1), device=action_ids.device)
            # set all timesteps from the sampled t through T to 1
            for i, t in enumerate(random_timesteps):
                action_mask[i, t:] = 1
            # move the mask to the same device and dtype as x_THW
            action_mask = action_mask.unsqueeze(-1).cuda().to(x_THW.dtype)
        # replace masked token ids with the masked-token latent
        x_THW[relevant_mask] = self.mask_token
        x_THW = self.patchify(x_THW)
        latents_CTHW, action_outputs = self.compute_latents(x_THW, action_ids=action_ids, domain=domain, action_mask=action_mask, **kwargs)
        labels = rearrange(labels, "B (T H W) C -> B T H W C", T=T, H=H, W=W)
        labels = self.patchify(labels)
        relevant_loss = torch.zeros(1).to(x_THW.device)
        relevant_acc = torch.zeros(1).to(x_THW.device)
        relevant_mask = self.patchify(relevant_mask[..., None]).sum(-1) > 0  # a patch counts as masked if any of its tokens is masked
        # record the loss over masked tokens only, to make it more comparable to LLM baselines
        if self.config.jointly_predict_states:
            # could also get the mask of corrupted tokens by uncommenting a line in `get_maskgit_collator`
            relevant_loss, relevant_acc = self.compute_video_loss_and_acc(latents_CTHW, labels, relevant_mask)
        # compute the action losses
        if action_outputs is not None:
            action_loss, _ = self.compute_action_loss_and_acc(action_outputs, action_labels, domain, action_mask)
            return ModelOutput(loss=relevant_loss, acc=relevant_acc, logits=latents_CTHW, action_loss=action_loss, actions=action_outputs)
        return ModelOutput(loss=relevant_loss, acc=relevant_acc, logits=latents_CTHW)

    def generate(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.LongTensor,
        max_new_tokens: int,
        min_new_tokens: int = None,
        return_logits: bool = False,
        return_with_actions: bool = False,
        temperature: float = 1.0,
        action_ids: torch.Tensor = None,
        domain: str = "default",
        action_only: bool = False,
        state_only: bool = False,
        **kwargs
    ) -> tuple[torch.LongTensor, torch.FloatTensor]:
        """
        Args designed to match the format of Llama.
        We ignore `attention_mask` and use `max_new_tokens` to determine the number of frames to generate.

        Returns: `(sample_THW, factored_logits)` if `return_logits` else `sample_THW`
            sample_THW: size (B, num_new_frames * H * W) corresponding to autoregressively generated
                unfactorized token ids for future frames.
            Optionally, factored_logits: size (B, factored_vocab_size, num_factored_vocabs, num_new_frames, H, W).
        """
        assert min_new_tokens in (None, max_new_tokens), \
            "Expecting `min_new_tokens`, if specified, to match `max_new_tokens`."
        # assert max_new_tokens % self.config.S == 0, "Expecting `max_new_tokens` to be a multiple of `self.config.S`."
        h, w, c = self.h, self.w, self.vae_embed_dim
        if "h" in kwargs:
            h = kwargs["h"][0]
        if "w" in kwargs:
            w = kwargs["w"][0]
        S = h * w
        num_new_frames = max_new_tokens // S
        inputs_THW = rearrange(input_ids.clone(), "b (t h w) c -> b t h w c", h=h, w=w)
        inputs_masked_THW = torch.cat([
            inputs_THW,
            self.mask_token[None, None].repeat(inputs_THW.size(0), num_new_frames, h, w, 1)
        ], dim=1)
        all_factored_logits = []
        for timestep in range(inputs_THW.size(1), inputs_THW.size(1) + num_new_frames):
            # could change sampling hparams
            sample_HW, factored_logits, actions = self.maskgit_generate(
                inputs_masked_THW,
                timestep,
                maskgit_steps=self.maskgit_steps,
                temperature=temperature,
                action_ids=action_ids,
                domain=domain,
                action_only=action_only,
                state_only=state_only,
                **kwargs
            )
            inputs_masked_THW[:, timestep] = sample_HW
            all_factored_logits.append(factored_logits)
        predicted_tokens = rearrange(inputs_masked_THW, "B T H W C -> B (T H W) C")
        if return_with_actions:
            # unnormalize actions
            actions = self.action_preprocessor[domain[0]].unnormalize(actions)
            return predicted_tokens, actions
        elif return_logits:
            return predicted_tokens, torch.stack(all_factored_logits, dim=3)  # (b, c, num_new_frames, h, w)
        else:
            return predicted_tokens
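
    # Usage sketch (added; not in the original file -- shapes are assumptions based on the
    # rearranges above): `input_ids` holds the flattened latents of the prompt frames,
    # shaped (B, T_prompt * H * W, C), and `max_new_tokens` should be a multiple of H * W:
    #   tokens = model.generate(input_ids, attention_mask=None,
    #                           max_new_tokens=2 * model.h * model.w)
    #   # tokens: (B, (T_prompt + 2) * H * W, C), with the last 2 frames newly sampled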

    def sample_orders(self, bsz):
        # generate a batch of random generation orders
        orders = []
        for _ in range(bsz):
            order = np.array(list(range(self.seq_len)))
            np.random.shuffle(order)
            orders.append(order)
        orders = torch.Tensor(np.array(orders)).cuda().long()
        return orders
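
    # Illustrative example (added; not in the original file): for seq_len=4 and bsz=2,
    # `sample_orders` might return
    #   tensor([[2, 0, 3, 1],
    #           [1, 3, 0, 2]], device='cuda:0')
    # i.e. one independent random permutation of the spatial positions per sample.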

    def maskgit_generate(
        self,
        prompt_THW,
        out_t: int,
        unmask_mode: str = "random",
        action_ids=None,
        domain="default",
        maskgit_steps=8,
        cfg=1.0,
        temperature=1.0,
        cfg_schedule="linear",
        action_only: bool = False,
        state_only: bool = False,
        **kwargs
    ) -> tuple[torch.LongTensor, torch.FloatTensor]:
        # init and sample generation orders
        assert out_t, "maskgit_generate requires out_t > 0"
        # patchify the prompt
        prompt_THW = self.patchify(prompt_THW)
        bs, t, h, w = prompt_THW.size(0), prompt_THW.size(1), prompt_THW.size(2), prompt_THW.size(3)
        S = h * w
        orders = self.sample_orders(bs)  # random order
        sampled_action_token_latent = None
        # this will be modified in place on each iteration of this loop
        unmasked = self.init_mask(prompt_THW)
        latents_CTHW, action_outputs = self.compute_latents(prompt_THW, action_ids=action_ids, domain=domain, **kwargs)
        latents_CHW = latents_CTHW[:, :, out_t]
        # return these original logits, not logits after partially sampling
        orig_latents_CHW = latents_CHW.clone()
        for step in range(maskgit_steps):
            # Perform a single maskgit step (cosine schedule), updating `unmasked` in place
            if step > 0:  # recompute logits with updated prompt
                latents_CHW, action_outputs = self.compute_latents(prompt_THW, action_ids=action_ids, domain=domain, **kwargs)
                latents_CHW = latents_CHW[:, :, out_t]
            # mask ratio for the next round, following MaskGIT and MAGE; for maskgit_steps=8
            # the ratios are roughly 0.98, 0.92, 0.83, 0.71, 0.56, 0.38, 0.20, 0.0
            mask_ratio = np.cos(math.pi / 2. * (step + 1) / maskgit_steps)
            mask_len = torch.Tensor([np.floor(self.seq_len * mask_ratio)]).cuda()
            # keep at least one token masked for the next iteration
            mask_len = torch.maximum(torch.Tensor([1]).cuda(),
                                     torch.minimum(torch.sum(~unmasked, dim=-1, keepdims=True) - 1, mask_len))
            # get masking for the next iteration and locations to be predicted in this iteration
            mask_next = mask_by_order(mask_len[0], orders, bs, self.seq_len)
            mask = ~unmasked
            if step >= maskgit_steps - 1:
                mask_to_pred = mask[:bs].bool()  # last step: predict everything still masked
            else:
                mask_to_pred = torch.logical_xor(mask[:bs].bool(), mask_next.bool())
            mask = mask_next
            if not cfg == 1.0:
                mask_to_pred = torch.cat([mask_to_pred, mask_to_pred], dim=0)
            # sample token latents for this step
            latents_CHW = rearrange(latents_CHW, "b c h w -> b (h w) c")
            latents_CHW = latents_CHW[mask_to_pred.nonzero(as_tuple=True)]
            # cfg schedule follows Muse
            total_mask_len = unmasked.shape[1]
            if cfg_schedule == "linear":
                cfg_iter = 1 + (cfg - 1) * (total_mask_len - unmasked.sum()) / total_mask_len
            elif cfg_schedule == "constant":
                cfg_iter = cfg
            else:
                raise NotImplementedError
            sampled_token_latent = self.diffloss.sample(latents_CHW.contiguous(), temperature, cfg_iter, clip_denoised=True)
            if not cfg == 1.0:
                sampled_token_latent, _ = sampled_token_latent.chunk(2, dim=0)  # remove null-class samples
                mask_to_pred, _ = mask_to_pred.chunk(2, dim=0)
            if action_outputs is not None and self.config.jointly_predict_actions:
                sampled_action_token_latent = self.action_diff_losses[domain[0]].sample(
                    action_outputs.view(-1, action_outputs.shape[-1]), temperature, cfg_iter, clip_denoised=True)
                if not cfg == 1.0:
                    sampled_action_token_latent, _ = sampled_action_token_latent.chunk(2, dim=0)
            # copy the newly sampled latents back into the prompt
            prompt_THW_reshape = rearrange(prompt_THW, "B T H W C -> B T (H W) C")
            prompt_THW_reshape[:, out_t][mask_to_pred.nonzero(as_tuple=True)] = sampled_token_latent
            prompt_THW = rearrange(prompt_THW_reshape, "B T (H W) C -> B T H W C", H=h, W=w)
        # Return the final sample and logits
        prompt_THW = self.unpatchify(prompt_THW)
        return prompt_THW[:, out_t], orig_latents_CHW, sampled_action_token_latent

    def maskgit_generate_horizon(
        self,
        prompt_THW,
        out_t_min: int,
        out_t_max: int,
        unmask_mode: str = "random",
        action_ids=None,
        domain="default",
        maskgit_steps=8,
        cfg=1.0,
        temperature=1.0,
        cfg_schedule="linear",
        **kwargs
    ) -> tuple[torch.LongTensor, torch.FloatTensor]:
        # init and sample generation orders
        # patchify the prompt
        prompt_THW = self.patchify(prompt_THW)
        bs, t, h, w = prompt_THW.size(0), prompt_THW.size(1), prompt_THW.size(2), prompt_THW.size(3)
        S = h * w
        orders = self.sample_orders(bs)  # random order
        # this will be modified in place on each iteration of this loop
        horizon = out_t_max - out_t_min
        unmasked = self.init_mask(prompt_THW, t=horizon)
        action_outputs = None  # stays None unless action prediction is enabled
        latents_CTHW, latents_actions = self.compute_latents(prompt_THW, action_ids=action_ids, domain=domain, **kwargs)
        latents_CHW = latents_CTHW[:, :, out_t_min:out_t_max]
        # return these original logits, not logits after partially sampling
        orig_latents_CHW = latents_CHW.clone()
        seq_len = horizon * self.seq_len
        for step in range(maskgit_steps):
            # Perform a single maskgit step (cosine schedule), updating `unmasked` in place
            if step > 0:  # recompute logits with updated prompt
                latents_CHW, latents_actions = self.compute_latents(prompt_THW, action_ids=action_ids, domain=domain, **kwargs)
                latents_CHW = latents_CHW[:, :, out_t_min:out_t_max]
            # mask ratio for the next round, following MaskGIT and MAGE
            mask_ratio = np.cos(math.pi / 2. * (step + 1) / maskgit_steps)
            mask_len = torch.Tensor([np.floor(seq_len * mask_ratio)]).cuda()
            # keep at least one token masked for the next iteration
            mask_len = torch.maximum(torch.Tensor([1]).cuda(),
                                     torch.minimum(torch.sum(~unmasked, dim=-1, keepdims=True) - 1, mask_len))
            # get masking for the next iteration and locations to be predicted in this iteration
            mask_next = mask_by_order(mask_len[0], orders, bs, seq_len)
            mask = ~unmasked
            if step >= maskgit_steps - 1:
                mask_to_pred = mask[:bs].bool()  # last step: predict everything still masked
            else:
                mask_to_pred = torch.logical_xor(mask[:bs].bool(), mask_next.bool())
            mask = mask_next
            if not cfg == 1.0:
                mask_to_pred = torch.cat([mask_to_pred, mask_to_pred], dim=0)
            # sample token latents for this step
            latents_CHW = rearrange(latents_CHW, "b c t h w -> b (t h w) c")
            latents_CHW = latents_CHW[mask_to_pred.nonzero(as_tuple=True)]
            # cfg schedule follows Muse
            total_mask_len = unmasked.shape[1]
            if cfg_schedule == "linear":
                cfg_iter = 1 + (cfg - 1) * (total_mask_len - unmasked.sum()) / total_mask_len
            elif cfg_schedule == "constant":
                cfg_iter = cfg
            else:
                raise NotImplementedError
            sampled_token_latent = self.diffloss.sample(latents_CHW.contiguous(), temperature, cfg_iter)
            if not cfg == 1.0:
                sampled_token_latent, _ = sampled_token_latent.chunk(2, dim=0)  # remove null-class samples
                mask_to_pred, _ = mask_to_pred.chunk(2, dim=0)
            if latents_actions is not None and self.config.jointly_predict_actions:
                action_outputs = self.action_diff_losses[domain[0]].sample(
                    latents_actions.view(-1, latents_actions.shape[-1]), temperature, cfg_iter)
                if not cfg == 1.0:
                    action_outputs, _ = action_outputs.chunk(2, dim=0)
            # copy the newly sampled latents back into the prompt for the predicted horizon
            prompt_THW_reshape = rearrange(prompt_THW[:, out_t_min:out_t_max], "B T H W C -> B (T H W) C")
            prompt_THW_reshape[mask_to_pred.nonzero(as_tuple=True)] = sampled_token_latent
            prompt_THW[:, out_t_min:out_t_max] = rearrange(prompt_THW_reshape.clone(), "B (T H W) C -> B T H W C", T=horizon, H=h, W=w)
        # Return the final sample and logits
        prompt_THW = self.unpatchify(prompt_THW)
        return prompt_THW[:, out_t_min:out_t_max], orig_latents_CHW, action_outputs
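

if __name__ == "__main__":
    # Standalone sketch (added; not part of the original repo): a quick, model-free
    # illustration of the cosine unmasking schedule used in `maskgit_generate` above.
    # It prints how many token latents get sampled at each MaskGIT step, assuming a
    # 16x16 latent frame (seq_len = 256) and the default maskgit_steps = 8.
    seq_len, maskgit_steps = 16 * 16, 8
    num_masked = seq_len  # everything starts masked
    for step in range(maskgit_steps):
        mask_ratio = np.cos(math.pi / 2. * (step + 1) / maskgit_steps)
        # next number of masked tokens, clamped exactly as in `maskgit_generate`
        next_masked = int(max(1, min(num_masked - 1, np.floor(seq_len * mask_ratio))))
        if step == maskgit_steps - 1:
            predicted, next_masked = num_masked, 0  # last step predicts everything still masked
        else:
            predicted = num_masked - next_masked
        print(f"step {step}: predict {predicted:3d} tokens, {next_masked:3d} remain masked")
        num_masked = next_masked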