from typing import Optional

import numpy as np
import torch
from torch import nn, Tensor
from einops import rearrange

from genie.attention import SelfAttention


class Mlp(nn.Module):
    """Standard transformer feed-forward block: Linear -> GELU -> Linear, with dropout."""

    def __init__(
        self,
        d_model: int,
        mlp_ratio: float = 4.0,
        mlp_bias: bool = True,
        mlp_drop: float = 0.0,
    ) -> None:
        super().__init__()
        hidden_dim = int(d_model * mlp_ratio)
        self.fc1 = nn.Linear(d_model, hidden_dim, bias=mlp_bias)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, d_model, bias=mlp_bias)
        self.drop = nn.Dropout(mlp_drop)

    def forward(self, x: Tensor) -> Tensor:
        x = self.drop(self.act(self.fc1(x)))
        x = self.drop(self.fc2(x))
        return x


class STBlock(nn.Module):
    # Spatiotemporal block; see Figure 4 of https://arxiv.org/pdf/2402.15391.pdf
    def __init__(
        self,
        num_heads: int,
        d_model: int,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        qk_norm: bool = True,
        use_mup: bool = True,
        attn_drop: float = 0.05,
        mlp_ratio: float = 4.0,
        mlp_bias: bool = True,
        mlp_drop: float = 0.05,
        # action-conditioning options
        action_processing: str = "mlp",
        jointly_predict_actions: bool = False,
        mask_token_id: int = 0,
    ) -> None:
        super().__init__()
        # When qk_norm is enabled, normalization happens inside SelfAttention, so pre-norm is skipped here.
        self.norm1 = nn.Identity() if qk_norm else nn.LayerNorm(d_model, eps=1e-05)
        # Spatial attention: the sequence dim runs over each frame's 16x16 patch tokens.
        self.spatial_attn = SelfAttention(
            num_heads=num_heads,
            d_model=d_model,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            qk_norm=qk_norm,
            use_mup=use_mup,
            attn_drop=attn_drop,
        )
        # Temporal attention: the sequence dim runs over the frame sequence (e.g. 16 frames).
        self.temporal_attn = SelfAttention(
            num_heads=num_heads,
            d_model=d_model,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            qk_norm=qk_norm,
            use_mup=use_mup,
            attn_drop=attn_drop,
        )
        self.action_prediction = jointly_predict_actions
        self.action_processing = action_processing
        self.norm2 = nn.Identity() if qk_norm else nn.LayerNorm(d_model, eps=1e-05)
        self.mlp = Mlp(d_model=d_model, mlp_ratio=mlp_ratio, mlp_bias=mlp_bias, mlp_drop=mlp_drop)
        self.action_projectors = None  # per-domain action projectors, set at run-time

    def forward(self, x_TSC: Tensor, action_ids: Optional[Tensor] = None, domain: Optional[str] = None) -> Tensor:
        """
        Main forward pass of the STBlock: action conditioning (with options),
        (bidirectional) spatial attention, (causal) temporal attention, and action masking.

        x_TSC has shape [B, T, S, C]: batch, time, spatial tokens, channels.
        """
        T, S = x_TSC.size(1), x_TSC.size(2)

        # Process attention spatially: fold time into the batch dim.
        x_SC = rearrange(x_TSC, 'B T S C -> (B T) S C')
        x_SC = x_SC + self.spatial_attn(self.norm1(x_SC))

        # Process attention temporally: fold space into the batch dim.
        x_TC = rearrange(x_SC, '(B T) S C -> (B S) T C', T=T)

        if action_ids is not None and domain is not None and self.action_projectors is not None:
            # action_ids: [B, T, D]. Only applied to the video tokens.
            if "mlp" in self.action_processing:
                action_ids = self.action_projectors[domain](action_ids)  # does not depend on x_TC
                x_TC = rearrange(x_TC, '(B S) T C -> B S T C', S=S)
                x_TC = x_TC + action_ids[:, None, :x_TC.shape[2]]  # broadcast across the spatial dim
                x_TC = rearrange(x_TC, 'B S T C -> (B S) T C', S=S)
            elif "cross_attention" in self.action_processing:
                x_TC = x_TC + self.action_projectors[domain](x_TC, action_ids, action_ids)
            elif "modulate" in self.action_processing:
                x_TC = x_TC + self.action_projectors[domain](x_TC, action_ids)

        # Apply the causal temporal transformer and the MLP.
        x_TC = x_TC + self.temporal_attn(x_TC, causal=True)  # [(B S), T, C]
        x_TC = x_TC + self.mlp(self.norm2(x_TC))

        x_TSC = rearrange(x_TC, '(B S) T C -> B T S C', S=S)
        return x_TSC
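

# Note (assumption, not from the original code): `action_projectors` is attached
# after construction by the owning model. For the "mlp" path it could be a
# per-domain projection from the raw action dimension to d_model, for example:
#
#     block.action_projectors = nn.ModuleDict(
#         {"some_domain": nn.Linear(action_dim, d_model)}  # names/dims illustrative
#     )
#
# so that `self.action_projectors[domain](action_ids)` maps [B, T, action_dim]
# to [B, T, d_model] before it is broadcast over the spatial tokens in forward().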


class STTransformerDecoder(nn.Module):
    def __init__(
        self,
        num_layers: int,
        num_heads: int,
        d_model: int,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        qk_norm: bool = True,
        use_mup: bool = True,
        attn_drop: float = 0.0,
        mlp_ratio: float = 4.0,
        mlp_bias: bool = True,
        mlp_drop: float = 0.0,
        # action-conditioning options
        action_processing: str = "mlp",
        jointly_predict_actions: bool = False,
        random_dummy_action: bool = True,
        mask_token_id: int = 0,
    ) -> None:
        super().__init__()
        self.layers = nn.ModuleList([
            STBlock(
                num_heads=num_heads,
                d_model=d_model,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                qk_norm=qk_norm,
                use_mup=use_mup,
                attn_drop=attn_drop,
                mlp_ratio=mlp_ratio,
                mlp_bias=mlp_bias,
                mlp_drop=mlp_drop,
                action_processing=action_processing,
                jointly_predict_actions=jointly_predict_actions,
                mask_token_id=mask_token_id,
            )
            for _ in range(num_layers)
        ])
        self.apply(self._init_weights)
self.apply(self._init_weights) | |
def _init_weights(self, m): | |
""" | |
Weight initialization for transformer | |
""" | |
if isinstance(m, nn.Linear): | |
torch.nn.init.xavier_uniform_(m.weight, gain=0.1) | |
if isinstance(m, nn.Linear) and m.bias is not None: | |
nn.init.constant_(m.bias, 0) | |
elif isinstance(m, nn.LayerNorm): | |
nn.init.constant_(m.bias, 0) | |
nn.init.constant_(m.weight, 1.0) | |

    def forward(self, tgt: Tensor, action_ids: Optional[Tensor] = None, domain: str = "") -> Tensor:
        x = tgt
        for layer in self.layers:
            x = layer(x, action_ids=action_ids, domain=domain)
        return x
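

# Minimal usage sketch (not from the original repo): runs a small decoder on random
# tokens to illustrate the [B, T, S, C] shape contract. It assumes
# `genie.attention.SelfAttention` is importable and that d_model=256 with
# 16x16 = 256 spatial patch tokens per frame is a valid configuration; the batch
# size and frame count below are illustrative only.
if __name__ == "__main__":
    decoder = STTransformerDecoder(num_layers=2, num_heads=8, d_model=256)
    x_TSC = torch.randn(2, 16, 256, 256)  # [B=2 clips, T=16 frames, S=256 patch tokens, C=256 channels]
    out = decoder(x_TSC)  # action_ids is None, so the action-conditioning branch is skipped
    print(out.shape)  # torch.Size([2, 16, 256, 256]): each STBlock preserves the input shape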