Spaces:
Runtime error
Runtime error
import numpy as np | |
import torch | |
import torch.nn as nn | |
# import torch.nn.functional as F | |
import clip | |
# from model.rotation2xyz import Rotation2xyz | |
import utils.model_utils as model_utils | |
# from model.PointNet2 import PointnetPP | |
# from model.DGCNN import PrimitiveNet | |
# from utils.anchor_utils import masking_load_driver, anchor_load_driver ### load driver; masking load driver # | |
# import os | |
### MDM 10 ### | |
class MDMV10(nn.Module): | |
def __init__(self, modeltype, njoints, nfeats, num_actions, translation, pose_rep, glob, glob_rot,
             latent_dim=256, ff_size=1024, num_layers=8, num_heads=4, dropout=0.1,
             ablation=None, activation="gelu", legacy=False, data_rep='rot6d', dataset='amass', clip_dim=512,
             arch='trans_enc', emb_trans_dec=False, clip_version=None, **kargs):
    """Store the configuration and build the sub-networks enabled on ``args``.

    All named arguments are stored on the instance. ``kargs`` is consulted
    for ``action_emb``, ``normalize_encoder_output``, ``cond_mode``,
    ``cond_mask_prob`` and ``args``. ``args`` must expose the option
    attributes read below (``diff_jts``, ``diff_basejtsrel``,
    ``diff_basejtse``, ``diff_realbasejtsrel``,
    ``diff_realbasejtsrel_to_joints``, ``use_anchors``,
    ``use_ours_transformer_enc``, ...); if it is missing, the first
    attribute read fails. ``clip_version`` is accepted but not used here.

    Each ``diff_*`` switch enables one independent group of modules
    (input projection, positional encoders, encoders, timestep embedders,
    denoising encoder, output head).

    NOTE: the statement order is kept on purpose. It determines the order
    in which sub-modules are registered and their weights are created.
    """
    super().__init__()

    # ---- constructor arguments, stored unchanged ----
    self.legacy = legacy
    self.modeltype = modeltype
    self.njoints = njoints
    self.nfeats = nfeats
    self.num_actions = num_actions
    self.data_rep = data_rep
    self.dataset = dataset
    self.pose_rep = pose_rep
    self.glob = glob
    self.glob_rot = glob_rot
    self.translation = translation
    self.latent_dim = latent_dim
    self.ff_size = ff_size
    self.num_layers = num_layers
    self.num_heads = num_heads
    self.dropout = dropout
    self.ablation = ablation
    self.activation = activation
    self.clip_dim = clip_dim

    # ---- optional settings pulled from **kargs ----
    self.action_emb = kargs.get('action_emb', None)
    self.input_feats = self.njoints * self.nfeats
    self.normalize_output = kargs.get('normalize_encoder_output', False)
    self.cond_mode = kargs.get('cond_mode', 'no_cond')
    self.cond_mask_prob = kargs.get('cond_mask_prob', 0.)
    self.args = kargs.get('args', None)
    opts = self.args  # short alias; every option below is read from it

    # ---- which diffusion branches to build ----
    self.diff_jts = opts.diff_jts
    self.diff_basejtsrel = opts.diff_basejtsrel
    self.diff_basejtse = opts.diff_basejtse
    self.diff_realbasejtsrel = opts.diff_realbasejtsrel
    self.diff_realbasejtsrel_to_joints = opts.diff_realbasejtsrel_to_joints

    self.arch = arch
    # Extra embedding width is only reserved for the 'gru' architecture.
    if self.arch == 'gru':
        self.gru_emb_dim = self.latent_dim
    else:
        self.gru_emb_dim = 0

    self.use_anchors = opts.use_anchors
    self.nn_keypoints = 21
    self.joints_feats_in_dim = self.nn_keypoints * 3
    # The representation is pinned to "xyz" whatever `data_rep` was passed.
    self.data_rep = "xyz"

    rep = self.data_rep
    dim = self.latent_dim
    jts_dim = self.joints_feats_in_dim
    n_kp = self.nn_keypoints

    # ---- local factories for the layer stacks repeated in every branch ----
    def new_pos_encoding():
        return PositionalEncoding(dim, self.dropout)

    def new_torch_encoder():
        layer = nn.TransformerEncoderLayer(
            d_model=dim,
            nhead=self.num_heads,
            dim_feedforward=self.ff_size,
            dropout=self.dropout,
            activation=self.activation,
        )
        return nn.TransformerEncoder(layer, num_layers=self.num_layers)

    def new_denoising_encoder():
        # Project-specific encoder when requested, torch encoder otherwise.
        if opts.use_ours_transformer_enc:
            return model_utils.TransformerEncoder(
                hidden_size=dim,
                fc_size=self.ff_size,
                num_heads=self.num_heads,
                layer_norm=True,
                num_layers=self.num_layers,
                dropout_rate=0.2,
                re_zero=True,
                memory_efficient=False,
            )
        return new_torch_encoder()

    def new_latent_mlp():
        # Two-layer MLP: dim -> 2*dim -> dim, with a ReLU in between.
        return nn.Sequential(
            nn.Linear(dim, dim * 2), nn.ReLU(),
            nn.Linear(dim * 2, dim),
        )

    # ================= branch: joint sequence =================
    if self.diff_jts:
        self.joint_sequence_input_process = InputProcess(rep, jts_dim, dim)
        self.joint_sequence_pos_encoder = new_pos_encoding()
        self.emb_trans_dec = emb_trans_dec
        print("TRANS_ENC init")
        self.joint_sequence_seqTransEncoder = new_torch_encoder()
        self.joint_sequence_logvar_seqTransEncoder = new_torch_encoder()
        self.joint_sequence_embed_timestep = TimestepEmbedder(dim, self.joint_sequence_pos_encoder)

        # Denoising part; its timestep embedder shares the encoder-side
        # positional encoding, as in the original code.
        self.joint_sequence_denoising_embed_timestep = TimestepEmbedder(dim, self.joint_sequence_pos_encoder)
        self.joint_sequence_denoising_pos_encoder = new_pos_encoding()
        self.joint_sequence_denoising_seqTransEncoder = new_denoising_encoder()
        if opts.const_noise:
            self.glb_denoising_latents_trans_layer = new_latent_mlp()
        self.joint_sequence_output_process = OutputProcess(rep, jts_dim, dim, 21, 3)

    # ========== branch: real base-joint relations -> joints ==========
    if self.diff_realbasejtsrel_to_joints:
        self.rel_input_feats = 21 * (3 + 3 + 3)  # base pts, normals, relative positions
        if opts.use_abs_jts_for_encoding_obj_base:
            self.rel_input_feats = 21 * 3
            to_joints_input_cls = InputProcessObjBaseV2
        elif opts.use_objbase_v2:
            self.rel_input_feats = 3 + 3 + (21 * 3)
            to_joints_input_cls = InputProcessObjBaseV2
        elif opts.use_objbase_v3:
            self.rel_input_feats = 3 + 3 + (21 * 3)
            to_joints_input_cls = InputProcessObjBaseV3
        else:
            to_joints_input_cls = InputProcessObjBase
        self.real_basejtsrel_to_joints_input_process = to_joints_input_cls(
            rep, self.rel_input_feats, dim, layernorm=True)
        if opts.use_abs_jts_for_encoding:
            # Overrides whichever input projection was chosen above.
            self.real_basejtsrel_to_joints_input_process = InputProcess(rep, jts_dim, dim)

        self.real_basejtsrel_to_joints_sequence_pos_encoder = new_pos_encoding()
        self.real_basejtsrel_to_joints_seqTransEncoder = new_torch_encoder()
        self.real_basejtsrel_to_joints_embed_timestep = TimestepEmbedder(
            dim, self.real_basejtsrel_to_joints_sequence_pos_encoder)
        self.real_basejtsrel_to_joints_sequence_pos_denoising_encoder = new_pos_encoding()
        self.real_basejtsrel_to_joints_denoising_seqTransEncoder = new_denoising_encoder()
        self.real_basejtsrel_to_joints_output_process = OutputProcess(rep, jts_dim, dim, 21, 3)

    # ================= branch: real base-joint relations =================
    if self.diff_realbasejtsrel:
        self.rel_input_feats = n_kp * (3 + 3 + 3)
        extra_kwargs = {}
        if opts.use_objbase_v2:
            self.rel_input_feats = 3 + 3 + (n_kp * 3)
            input_cls = InputProcessObjBaseV2
            extra_kwargs = {'glb_feats_trans': True}
        elif opts.use_objbase_v4:
            self.rel_input_feats = opts.nn_base_pts * (3 + 3 + 3)
            input_cls = InputProcessObjBaseV4
        elif opts.use_objbase_v5:
            if opts.v5_in_not_base:
                self.rel_input_feats = n_kp * 3
            elif opts.v5_in_not_base_pos:
                self.rel_input_feats = 3 + (n_kp * 3)
            else:
                self.rel_input_feats = 3 + 3 + (n_kp * 3)
            input_cls = InputProcessObjBaseV5
            extra_kwargs = {'without_glb': opts.v5_in_without_glb}
        elif opts.use_objbase_v6:
            self.rel_input_feats = 3 + 3 + (n_kp * 3) + 3
            input_cls = InputProcessObjBaseV6
        elif opts.use_objbase_v7:
            self.rel_input_feats = 3 + 3 + (n_kp * 3)
            input_cls = InputProcessObjBaseV7
        else:
            input_cls = InputProcessObjBase
        self.real_basejtsrel_input_process = input_cls(
            rep, self.rel_input_feats, dim, layernorm=True, **extra_kwargs)

        self.real_basejtsrel_sequence_pos_encoder = new_pos_encoding()
        self.real_basejtsrel_seqTransEncoder = new_torch_encoder()
        self.real_basejtsrel_embed_timestep = TimestepEmbedder(dim, self.real_basejtsrel_sequence_pos_encoder)
        self.real_basejtsrel_sequence_pos_denoising_encoder = new_pos_encoding()
        self.real_basejtsrel_denoising_seqTransEncoder = new_denoising_encoder()

        print(f"not_cond_base: {self.args.not_cond_base}, latent_dim: {self.latent_dim}")
        if opts.use_jts_pert_realbasejtsrel:
            print(f"use_jts_pert_realbasejtsrel!!!!!!")
            self.real_basejtsrel_output_process = OutputProcess(rep, jts_dim, dim, n_kp, 3)
        elif opts.use_objbase_out_v3:
            self.real_basejtsrel_output_process = OutputProcessObjBaseRawV3(
                rep, dim, not_cond_base=opts.not_cond_base)
        elif opts.use_objbase_out_v4:
            self.real_basejtsrel_output_process = OutputProcessObjBaseRawV4(
                rep, dim, not_cond_base=opts.not_cond_base)
        elif opts.use_objbase_out_v5:
            self.real_basejtsrel_output_process = OutputProcessObjBaseRawV5(
                rep, dim,
                not_cond_base=opts.not_cond_base,
                out_objbase_v5_bundle_out=opts.out_objbase_v5_bundle_out,
                v5_out_not_cond_base=opts.v5_out_not_cond_base,
                nn_keypoints=n_kp,
            )
        else:
            self.real_basejtsrel_output_process = OutputProcessObjBaseRawV2(
                rep, dim, not_cond_base=opts.not_cond_base)

    # ================= branch: base-joint relations =================
    if self.diff_basejtsrel:
        self.rel_input_feats = n_kp * (3 + 3 + 3)
        self.avg_joints_sequence_input_process = InputProcess(rep, jts_dim, dim)
        if opts.with_glb_info:
            self.joints_offset_input_process = InputProcessWithGlbInfo(rep, jts_dim, dim)
        else:
            self.joints_offset_input_process = InputProcess(rep, jts_dim, dim)
        if opts.not_cond_base:
            self.rel_input_feats = n_kp * 3

        self.sequence_pos_encoder = new_pos_encoding()
        self.emb_trans_dec = emb_trans_dec
        self.seqTransEncoder = new_torch_encoder()
        self.logvar_seqTransEncoder = new_torch_encoder()
        self.embed_timestep = TimestepEmbedder(dim, self.sequence_pos_encoder)

        self.basejtsrel_denoising_embed_timestep = TimestepEmbedder(dim, self.sequence_pos_encoder)
        self.sequence_pos_denoising_encoder = new_pos_encoding()
        self.basejtsrel_denoising_seqTransEncoder = new_denoising_encoder()
        if opts.const_noise:
            self.basejtsrel_glb_denoising_latents_trans_layer = new_latent_mlp()

        self.avg_joint_sequence_output_process = OutputProcess(rep, jts_dim, dim, n_kp, 3)
        self.joint_offset_output_process = OutputProcess(rep, jts_dim, dim, n_kp, 3)
        if opts.use_dec_rel_v2:
            self.output_process = OutputProcessObjBaseRawV2(rep, dim, not_cond_base=opts.not_cond_base)
        else:
            self.output_process = OutputProcessObjBaseRaw(rep, dim, not_cond_base=opts.not_cond_base)

    # ================= branch: base-joint "e" features =================
    if self.diff_basejtse:
        self.input_feats_e = n_kp * (3 + 3 + 1 + 1)
        self.input_process_e = InputProcessObjBase(rep, self.input_feats_e + self.gru_emb_dim, dim)
        self.sequence_pos_encoder_e = new_pos_encoding()
        self.emb_trans_dec = emb_trans_dec
        print("TRANS_ENC init")
        self.seqTransEncoder_e = new_torch_encoder()
        print("TRANS_ENC init")
        self.logvar_seqTransEncoder_e = new_torch_encoder()

        self.embed_timestep_e = TimestepEmbedder(dim, self.sequence_pos_encoder_e)
        self.sequence_pos_denoising_encoder_e = new_pos_encoding()
        self.basejtse_denoising_embed_timestep = TimestepEmbedder(dim, self.sequence_pos_encoder_e)
        self.basejtse_denoising_seqTransEncoder = new_denoising_encoder()
        if opts.const_noise:
            # NOTE(review): this replaces the denoising encoder built just
            # above with an MLP, whereas the other branches store the MLP
            # under a separate `*_glb_denoising_latents_trans_layer` name.
            # Possibly a copy-paste slip; kept as-is, confirm against forward().
            self.basejtse_denoising_seqTransEncoder = new_latent_mlp()
        self.output_process_e = OutputProcessObjBaseERaw(rep, dim)
def set_enc_to_eval(self):
    """Put the encoder-side sub-modules of every enabled branch into eval mode.

    For each of ``diff_jts``, ``diff_basejtse`` and ``diff_basejtsrel`` that
    is truthy, ``.eval()`` is called on that branch's input projection(s),
    positional encoder, sequence encoder and log-variance encoder.
    Denoising modules and output heads are not touched.
    """
    if self.diff_jts:
        self.joint_sequence_input_process.eval()
        self.joint_sequence_pos_encoder.eval()
        self.joint_sequence_seqTransEncoder.eval()
        self.joint_sequence_logvar_seqTransEncoder.eval()
    if self.diff_basejtse:
        self.input_process_e.eval()
        self.sequence_pos_encoder_e.eval()
        self.seqTransEncoder_e.eval()
        self.logvar_seqTransEncoder_e.eval()
    if self.diff_basejtsrel:
        # The constructor builds `avg_joints_sequence_input_process` and
        # `joints_offset_input_process` for this branch; `input_process`
        # appears there only in a commented-out line, so calling it
        # unconditionally raised AttributeError. It is still honoured
        # when present.
        for module_name in ('avg_joints_sequence_input_process',
                            'joints_offset_input_process',
                            'input_process'):
            module = getattr(self, module_name, None)
            if module is not None:
                module.eval()
        self.sequence_pos_encoder.eval()
        self.seqTransEncoder.eval()
        self.logvar_seqTransEncoder.eval()
def set_bn_to_eval(self):
    """Call the optional ``set_bn_no_training`` hook of the v6 input network.

    Does nothing unless ``self.args.use_objbase_v6`` is truthy. The hook is
    looked up as
    ``self.real_basejtsrel_input_process.pnpp_conv_net.set_bn_no_training``;
    if any link of that chain is missing, the call is skipped. What the hook
    does is defined outside this block (presumably it freezes BatchNorm
    layers; confirm in the network's source).

    Unlike the previous bare ``except: pass``, errors raised inside the hook
    (and KeyboardInterrupt/SystemExit) are no longer swallowed.
    """
    if not self.args.use_objbase_v6:
        return
    input_process = getattr(self, 'real_basejtsrel_input_process', None)
    conv_net = getattr(input_process, 'pnpp_conv_net', None)
    freeze_bn = getattr(conv_net, 'set_bn_no_training', None)
    if freeze_bn is not None:
        freeze_bn()
def parameters_wo_clip(self):
    """Return a list of all parameters whose name does not start with 'clip_model.'."""
    kept = []
    for param_name, param in self.named_parameters():
        if param_name.startswith('clip_model.'):
            continue
        kept.append(param)
    return kept
def load_and_freeze_clip(self, clip_version):
    """Load the CLIP model `clip_version` on CPU and return it frozen.

    The model is loaded with ``jit=False``, passed through
    ``clip.model.convert_weights``, switched to eval mode, and every
    parameter has ``requires_grad`` turned off. The preprocessing object
    returned by ``clip.load`` is discarded.
    """
    # jit=False is required for training, per the original note.
    frozen_clip, _preprocess = clip.load(clip_version, device='cpu', jit=False)
    # Per the original note this conversion is redundant (CLIP weights are
    # already float16 by default); kept to preserve behaviour.
    clip.model.convert_weights(frozen_clip)

    frozen_clip.eval()
    for weight in frozen_clip.parameters():
        weight.requires_grad_(False)
    return frozen_clip
def mask_cond(self, cond, force_mask=False):
    """Randomly blank out whole rows of a 2-D conditioning tensor.

    - ``force_mask`` truthy: return zeros shaped like ``cond``.
    - module in training mode and ``cond_mask_prob > 0``: each row is
      independently zeroed with probability ``cond_mask_prob``.
    - otherwise: return ``cond`` unchanged.
    """
    # Unpacking also rejects anything that is not exactly 2-D.
    batch_size, _ = cond.shape
    if force_mask:
        return torch.zeros_like(cond)
    if not (self.training and self.cond_mask_prob > 0.):
        return cond
    drop_prob = torch.ones(batch_size, device=cond.device) * self.cond_mask_prob
    # 1 -> row replaced by the null (zero) condition, 0 -> row kept.
    dropped = torch.bernoulli(drop_prob).view(batch_size, 1)
    return cond * (1. - dropped)
def encode_text(self, raw_text):
    """Tokenize `raw_text` with CLIP and return ``self.clip_model`` text features as float.

    `raw_text` is a list of prompt strings (one per batch element, per the
    original note). For the 'humanml', 'kit' and 'motion_ours' datasets the
    prompts are tokenized to a 22-token context (start + 20 + end),
    truncating longer prompts, then zero-padded back to 77 columns.
    For any other dataset the default tokenizer context is used with
    truncation enabled.
    """
    device = next(self.parameters()).device

    if self.dataset not in ['humanml', 'kit', 'motion_ours']:
        # Default context; over-long prompts are truncated.
        tokens = clip.tokenize(raw_text, truncate=True).to(device)
        return self.clip_model.encode_text(tokens).float()

    max_text_len = 20  # hard-coded limit for the datasets listed above
    default_context_length = 77
    context_length = max_text_len + 2  # start_token + 20 + end_token
    assert context_length < default_context_length
    tokens = clip.tokenize(raw_text, context_length=context_length, truncate=True).to(device)
    pad_width = default_context_length - context_length
    padding = torch.zeros([tokens.shape[0], pad_width], dtype=tokens.dtype, device=tokens.device)
    tokens = torch.cat([tokens, padding], dim=1)
    return self.clip_model.encode_text(tokens).float()
def dec_jts_only_fr_latents(self, latents_feats):
    """Decode joint-sequence latents into joints.

    Runs ``self.joint_sequence_output_process`` on `latents_feats`, moves the
    last axis of the result to position 1 and makes it contiguous.
    Per the original note the head returns [bs, njoints, nfeats, nframes],
    so the result would be [bs, nframes, njoints, nfeats] (confirm against
    OutputProcess).

    Returns a dict with the decoded tensor under "joint_seq_output" and the
    untouched input under "joints_seq_latents".
    """
    decoded = self.joint_sequence_output_process(latents_feats)
    frames_first = decoded.permute(0, 3, 1, 2).contiguous()
    return {
        "joint_seq_output": frames_first,
        "joints_seq_latents": latents_feats,
    }
def dec_basejtsrel_only_fr_latents(self, latent_feats, x):
    """Decode base-joint-relation latents.

    The first slice along axis 0 of `latent_feats` goes through
    ``self.avg_joint_sequence_output_process`` and has its trailing axis
    squeezed. The remaining slices go, together with `x`, through
    ``self.output_process``, whose 'dec_rel' entry is kept.

    Returns a dict with keys 'avg_jts_outputs' and 'basejtsrel_output'.
    """
    avg_latents = latent_feats[0:1, ...]
    rel_latents = latent_feats[1:, ...]

    # Per the original note the head yields bsz x njoints x nfeats x 1;
    # the singleton tail is dropped.
    avg_jts = self.avg_joint_sequence_output_process(avg_latents).squeeze(-1)
    rel_decoded = self.output_process(rel_latents, x)

    return {
        'avg_jts_outputs': avg_jts,
        'basejtsrel_output': rel_decoded['dec_rel'],
    }
# def dec_latents_to_joints_with_t(self, joints_seq_latents, timesteps): | |
def dec_latents_to_joints_with_t(self, input_latent_feats, x, timesteps): | |
# Run the timestep-conditioned denoising pass over already-encoded latents
# and decode the result, one section per enabled diffusion target.
#
# Arguments:
#   input_latent_feats: dict of latent tensors. Keys read below:
#     "joints_seq_latents" / "joints_seq_latents_enc"      (self.diff_jts),
#     "rel_base_pts_outputs" / "rel_base_pts_outputs_enc"  (self.diff_basejtsrel),
#     'base_jts_e_feats' / "base_jts_e_feats_enc"          (self.diff_basejtse).
#     The "*_enc" entries are only read when self.args.train_enc is set.
#   x: forwarded unchanged to self.output_process and self.output_process_e.
#   timesteps: forwarded to the per-section timestep embedders.
# Returns:
#   One dict holding the merged outputs of every enabled section
#   (an empty dict when no section is enabled).
#
# NOTE(review): indentation is missing in this copy of the file, so the
# branch nesting described in the comments below is inferred from the
# if/else pairing -- confirm against the original source.
# # logvar_seqTransEncoder_e, logvar_seqTransEncoder, joint_sequence_logvar_seqTransEncoder # # | |
# joints_seq_latents: seq x bs x d --> perturbed joitns_seq_latents \in [-1, 1] ## | |
# def dec_latents_to_joints_with_t(self, joints_seq_latents, timesteps): | |
## positional encoding for denoising ## | |
# rt_dict = { | |
# 'joint_seq_output': joint_seq_output, | |
# 'rel_base_pts_outputs': rel_base_pts_outputs, | |
# } | |
# NOTE(review): rt_dict is assigned a fresh {} again right before the final
# merge, so this first initialisation looks redundant.
rt_dict = {} | |
# ---- Section 1: joint-sequence latents (enabled by self.diff_jts) ----
if self.diff_jts: | |
####### input latent feats ####### | |
joints_seq_latents = input_latent_feats["joints_seq_latents"] | |
# Positional encoding is applied unless args.without_dec_pos_emb is set.
if not self.args.without_dec_pos_emb: | |
joints_seq_latents = self.joint_sequence_denoising_pos_encoder(joints_seq_latents) | |
# ### GET joints seq time embeddings ### ### embed time stamps ### | |
# joints_seq_time_emb = self.joint_sequence_embed_timestep(timesteps) | |
# joints_seq_latents = torch.cat( | |
# [joints_seq_time_emb, joints_seq_latents], dim=0 | |
# ) | |
# joints_seq_latents = self.joint_sequence_denoising_seqTransEncoder(joints_seq_latents)[1:] # seq x bs x d | |
# Two ways of injecting the timestep embedding:
#  * deep_fuse_timeemb: repeat it along dim 0 and add it to every entry;
#  * otherwise: prepend it as one extra leading entry, which is sliced off
#    again with [1:] after the encoder.
if self.args.deep_fuse_timeemb: | |
## denoising process ### | |
## GET joints seq time embeddings ### ### embed time stamps ### | |
joints_seq_time_emb = self.joint_sequence_embed_timestep(timesteps) | |
joints_seq_time_emb = joints_seq_time_emb.repeat(joints_seq_latents.size(0), 1, 1).contiguous() | |
joints_seq_latents = joints_seq_latents + joints_seq_time_emb | |
# With use_ours_transformer_enc the encoder is called with the first two axes
# swapped (and a set_attn_to_none keyword); the result is swapped back.
if self.args.use_ours_transformer_enc: | |
joints_seq_latents = self.joint_sequence_denoising_seqTransEncoder(joints_seq_latents.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) | |
joints_seq_latents = joints_seq_latents.permute(1, 0, 2) | |
else: | |
joints_seq_latents = self.joint_sequence_denoising_seqTransEncoder(joints_seq_latents) | |
else: | |
## denoising process ### | |
## GET joints seq time embeddings ### ### embed time stamps ### | |
joints_seq_time_emb = self.joint_sequence_embed_timestep(timesteps) | |
joints_seq_latents = torch.cat( | |
[joints_seq_time_emb, joints_seq_latents], dim=0 | |
) | |
if self.args.use_ours_transformer_enc: ## mdm ours ## | |
joints_seq_latents = self.joint_sequence_denoising_seqTransEncoder(joints_seq_latents.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) | |
joints_seq_latents = joints_seq_latents.permute(1, 0, 2)[1:] | |
else: | |
joints_seq_latents = self.joint_sequence_denoising_seqTransEncoder(joints_seq_latents)[1:] | |
# joints_seq_latents: seq_len x bsz x latent_dim # | |
# const_noise: take the element-wise max over dim 0, pass the single
# remaining entry through glb_denoising_latents_trans_layer, then repeat it
# back to the original length along dim 0.
if self.args.const_noise: | |
seq_len = joints_seq_latents.size(0) | |
# if self.args.const_noise: | |
joints_seq_latents, _ = torch.max(joints_seq_latents, dim=0, keepdim=True) | |
joints_seq_latents = self.glb_denoising_latents_trans_layer(joints_seq_latents) # seq_len x bsz x latent_dim | |
joints_seq_latents = joints_seq_latents.repeat(seq_len, 1, 1).contiguous() | |
### sequence latents ### | |
# train_enc: the latents computed above are replaced by the "_enc" entry of
# the input dict before decoding.
if self.args.train_enc: # trian enc for seq latents ### | |
joints_seq_latents = input_latent_feats["joints_seq_latents_enc"] | |
# bsz x ws x nnj x 3 # | |
joint_seq_output = self.joint_sequence_output_process(joints_seq_latents) # [bs, njoints, nfeats, nframes] | |
# bsz x ws x nnj x nnfeats # --> joints_seq_outputs # | |
joint_seq_output = joint_seq_output.permute(0, 3, 1, 2).contiguous() | |
diff_jts_dict = { | |
"joint_seq_output": joint_seq_output, | |
"joints_seq_latents": joints_seq_latents, | |
} | |
else: | |
diff_jts_dict = {} | |
# ---- Section 2: base-to-joint relative latents (enabled by self.diff_basejtsrel) ----
if self.diff_basejtsrel: | |
rel_base_pts_outputs = input_latent_feats["rel_base_pts_outputs"] | |
# A length-1 input is repeated to window_size + 1 entries when
# args.single_frame_noise is set.
if rel_base_pts_outputs.size(0) == 1 and self.args.single_frame_noise: | |
rel_base_pts_outputs = rel_base_pts_outputs.repeat(self.args.window_size + 1, 1, 1) | |
# Positional encoding is applied to entries 1.. only; entry 0 is
# re-attached in front unchanged.
if not self.args.without_dec_pos_emb: # without | |
avg_jts_inputs = rel_base_pts_outputs[0:1, ...] | |
other_rel_base_pts_outputs = rel_base_pts_outputs[1: , ...] | |
other_rel_base_pts_outputs = self.sequence_pos_denoising_encoder(other_rel_base_pts_outputs) | |
rel_base_pts_outputs = torch.cat( | |
[avg_jts_inputs, other_rel_base_pts_outputs], dim=0 | |
) | |
# Same two timestep-injection modes as in section 1.
if self.args.deep_fuse_timeemb: | |
## denoising process ### | |
## GET joints seq time embeddings ### ### embed time stamps ### | |
basejtsrel_time_emb = self.basejtsrel_denoising_embed_timestep(timesteps) | |
basejtsrel_time_emb = basejtsrel_time_emb.repeat(rel_base_pts_outputs.size(0), 1, 1).contiguous() | |
basejtsrel_seq_latents = rel_base_pts_outputs + basejtsrel_time_emb | |
if self.args.use_ours_transformer_enc: | |
basejtsrel_seq_latents = self.basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) | |
basejtsrel_seq_latents = basejtsrel_seq_latents.permute(1, 0, 2) | |
else: | |
basejtsrel_seq_latents = self.basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents) | |
else: | |
## denoising process ### | |
## GET joints seq time embeddings ### ### embed time stamps ### | |
basejtsrel_time_emb = self.basejtsrel_denoising_embed_timestep(timesteps) | |
basejtsrel_seq_latents = torch.cat( | |
[basejtsrel_time_emb, rel_base_pts_outputs], dim=0 | |
) | |
if self.args.use_ours_transformer_enc: ## mdm ours ## | |
basejtsrel_seq_latents = self.basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) | |
basejtsrel_seq_latents = basejtsrel_seq_latents.permute(1, 0, 2)[1:] | |
else: | |
basejtsrel_seq_latents = self.basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents)[1:] | |
### sequence latents ### | |
# Choose what gets decoded:
#  * train_enc: the "_enc" latents from the input dict (repeated like the
#    input above when they have length 1 and single_frame_noise is set);
#  * pred_diff_noise: the input latents minus the encoder output;
#  * otherwise: the encoder output itself.
if self.args.train_enc: # trian enc for seq latents ### | |
basejtsrel_seq_latents = input_latent_feats["rel_base_pts_outputs_enc"] | |
if basejtsrel_seq_latents.size(0) == 1 and self.args.single_frame_noise: | |
basejtsrel_seq_latents = basejtsrel_seq_latents.repeat(self.args.window_size + 1, 1, 1) | |
basejtsrel_seq_latents_pred_feats = basejtsrel_seq_latents | |
elif self.args.pred_diff_noise: | |
basejtsrel_seq_latents_pred_feats = input_latent_feats["rel_base_pts_outputs"] - basejtsrel_seq_latents | |
else: | |
basejtsrel_seq_latents_pred_feats = basejtsrel_seq_latents | |
# rel_base_pts_outputs = self.sequence_pos_denoising_encoder(rel_base_pts_outputs) | |
### GET joints seq output ### | |
# basejtsrel_denoising_embed_timestep, basejtsrel_denoising_seqTransEncoder, output_process # | |
# basejtsrel_time_emb = self.basejtsrel_denoising_embed_timestep(timesteps) | |
# basejtsrel_seq_latents = torch.cat( | |
# [basejtsrel_time_emb, rel_base_pts_outputs], dim=0 | |
# ) | |
# basejtsrel_seq_latents_pred_feats | |
# Entry 0 is decoded by avg_joint_sequence_output_process, the remaining
# entries by output_process (together with x).
avg_jts_seq_latents = basejtsrel_seq_latents_pred_feats[0:1, ...] | |
other_basejtsrel_seq_latents = basejtsrel_seq_latents_pred_feats[1:, ...] | |
avg_jts_outputs = self.avg_joint_sequence_output_process(avg_jts_seq_latents) # bsz x njoints x nfeats x 1 | |
avg_jts_outputs = avg_jts_outputs.squeeze(-1) # bsz x nnjoints x 1 --> avg joints here # | |
# basejtsrel_seq_latents = self.basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents)[1:] | |
basejtsrel_output = self.output_process(other_basejtsrel_seq_latents, x) | |
#### | |
# avg_jts_seq_latents = basejtsrel_seq_latents[0:1, ...] | |
# other_basejtsrel_seq_latents = basejtsrel_seq_latents[1:, ...] | |
# avg_jts_outputs = self.avg_joint_sequence_output_process(avg_jts_seq_latents) # bsz x njoints x nfeats x 1 | |
# avg_jts_outputs = avg_jts_outputs.squeeze(-1) # bsz x nnjoints x 1 --> avg joints here # | |
# # basejtsrel_seq_latents = self.basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents)[1:] | |
# basejtsrel_output = self.output_process(other_basejtsrel_seq_latents, x) | |
# Note: the latents stored here are basejtsrel_seq_latents, not the
# *_pred_feats tensor that was decoded (they differ under pred_diff_noise).
diff_basejtsrel_dict = { | |
"basejtsrel_output": basejtsrel_output['dec_rel'], | |
"basejtsrel_seq_latents": basejtsrel_seq_latents, | |
"avg_jts_outputs": avg_jts_outputs, | |
} | |
else: | |
diff_basejtsrel_dict = {} | |
# ---- Section 3: 'e' features (enabled by self.diff_basejtse) ----
if self.diff_basejtse: | |
# e_disp_rel_to_base_along_normals = input_latent_feats['e_disp_rel_to_base_along_normals'] | |
# e_disp_rel_to_baes_vt_normals = input_latent_feats['e_disp_rel_to_baes_vt_normals'] | |
base_jts_e_feats = input_latent_feats['base_jts_e_feats'] # seq x bs x d --> e feats | |
if not self.args.without_dec_pos_emb: | |
# rel_base_pts_outputs = self.sequence_pos_denoising_encoder(rel_base_pts_outputs) | |
base_jts_e_feats = self.sequence_pos_denoising_encoder_e(base_jts_e_feats) | |
# Same two timestep-injection modes as in section 1.
if self.args.deep_fuse_timeemb: | |
## denoising process ### | |
## GET joints seq time embeddings ### ### embed time stamps ### | |
base_jts_e_time_emb = self.embed_timestep_e(timesteps) | |
base_jts_e_time_emb = base_jts_e_time_emb.repeat(base_jts_e_feats.size(0), 1, 1).contiguous() | |
base_jts_e_feats = base_jts_e_feats + base_jts_e_time_emb | |
if self.args.use_ours_transformer_enc: ## transformer encoder ## | |
base_jts_e_feats = self.basejtse_denoising_seqTransEncoder(base_jts_e_feats.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) | |
base_jts_e_feats = base_jts_e_feats.permute(1, 0, 2) | |
else: | |
base_jts_e_feats = self.basejtse_denoising_seqTransEncoder(base_jts_e_feats) | |
else: | |
#### Decode e_along_normals and e_vt_normals #### | |
base_jts_e_time_emb = self.embed_timestep_e(timesteps) | |
base_jts_e_feats = torch.cat( | |
[base_jts_e_time_emb, base_jts_e_feats], dim=0 | |
) | |
if self.args.use_ours_transformer_enc: ## transformer encoder ## | |
base_jts_e_feats = self.basejtse_denoising_seqTransEncoder(base_jts_e_feats.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) | |
base_jts_e_feats = base_jts_e_feats.permute(1, 0, 2)[1:] | |
else: | |
base_jts_e_feats = self.basejtse_denoising_seqTransEncoder(base_jts_e_feats)[1:] | |
### sequence latents ### | |
# train_enc: decode the "_enc" entry of the input dict instead of the
# features computed above.
if self.args.train_enc: # trian enc for seq latents ### | |
base_jts_e_feats = input_latent_feats["base_jts_e_feats_enc"] | |
# base_jts_e_feats = self.sequence_pos_denoising_encoder_e(base_jts_e_feats) | |
#### Decode e_along_normals and e_vt_normals #### | |
### bae jts e embeddings feats ### | |
# base_jts_e_time_emb = self.embed_timestep_e(timesteps) | |
# base_jts_e_feats = torch.cat( | |
# [base_jts_e_time_emb, base_jts_e_feats], dim=0 | |
# ) | |
# base_jts_e_feats = self.basejtse_denoising_seqTransEncoder(base_jts_e_feats)[1:] | |
base_jts_e_output = self.output_process_e(base_jts_e_feats, x) | |
# dec_e_along_normals: bsz x (ws - 1) x nnb x nnj # | |
dec_e_along_normals = base_jts_e_output['dec_e_along_normals'] | |
dec_e_vt_normals = base_jts_e_output['dec_e_vt_normals'] | |
# Swap the last two axes of both decoded tensors.
dec_e_along_normals = dec_e_along_normals.contiguous().permute(0, 1, 3, 2).contiguous() | |
dec_e_vt_normals = dec_e_vt_normals.contiguous().permute(0, 1, 3, 2).contiguous() | |
#### Decode e_along_normals and e_vt_normals #### | |
diff_basejtse_dict = { | |
'dec_e_along_normals': dec_e_along_normals, | |
'dec_e_vt_normals': dec_e_vt_normals, | |
'base_jts_e_feats': base_jts_e_feats, | |
} | |
else: | |
diff_basejtse_dict = {} | |
# Merge the three per-section dicts into the single returned dict.
rt_dict = {} | |
rt_dict.update(diff_jts_dict) | |
rt_dict.update(diff_basejtsrel_dict) | |
rt_dict.update(diff_basejtse_dict) | |
### rt_dict --> rt_dict of joints, rel ### | |
return rt_dict | |
# return joint_seq_output, joints_seq_latents | |
def reparameterization(self, val_mean, val_var):
    """Draw a sample around ``val_mean`` using Gaussian noise.

    Noise with the shape of ``val_mean`` is drawn from a standard normal and
    scaled by ``val_var`` (used directly as the multiplier) before being
    added to the mean. When ``self.args.rnd_noise`` is set, the raw noise is
    returned instead of the shifted/scaled sample.
    """
    gaussian = torch.randn_like(val_mean)
    # The shifted sample is always computed, even if it ends up unused.
    shifted = val_mean + gaussian * val_var
    return gaussian if self.args.rnd_noise else shifted
def decode_realbasejtsrel_from_objbasefeats(self, objbasefeats, input_data):
    """Decode object-base features into relative positions.

    Calls ``self.real_basejtsrel_output_process(objbasefeats, input_data)``,
    takes its ``'dec_rel'`` entry and swaps axes 2 and 3 of that 5-D tensor,
    returning a contiguous result.
    """
    decoded = self.real_basejtsrel_output_process(objbasefeats, input_data)
    rel_positions = decoded['dec_rel']
    # Exchange the two middle axes; the outer two and the last one stay put.
    return rel_positions.permute(0, 1, 3, 2, 4).contiguous()
def denoising_realbasejtsrel_objbasefeats(self, pert_obj_base_pts_feats, timesteps):
    """Denoise perturbed object-base point features, conditioned on the timestep.

    The timestep embedding from ``self.real_basejtsrel_embed_timestep`` is
    injected in one of two ways, selected by ``self.args.deep_fuse_timeemb``:

    * set: the embedding is repeated along dim 0 and added to every entry;
    * unset: the embedding is prepended as one extra leading entry, and that
      entry is dropped again after the encoder.

    With ``self.args.use_ours_transformer_enc`` the encoder is called with
    the first two axes swapped plus a ``set_attn_to_none`` keyword, and its
    output is swapped back; otherwise it is called on the latents as-is.

    Args:
        pert_obj_base_pts_feats: latent features; dim 0 is the axis the time
            embedding is repeated over / concatenated along.
        timesteps: passed to ``self.real_basejtsrel_embed_timestep``.

    Returns:
        The encoder output, with the same dim-0 length as the input.
    """
    args = self.args
    encoder = self.real_basejtsrel_denoising_seqTransEncoder
    time_emb = self.real_basejtsrel_embed_timestep(timesteps)

    prepend_time_token = not args.deep_fuse_timeemb
    if prepend_time_token:
        seq_latents = torch.cat([time_emb, pert_obj_base_pts_feats], dim=0)
    else:
        time_emb = time_emb.repeat(pert_obj_base_pts_feats.size(0), 1, 1).contiguous()
        seq_latents = pert_obj_base_pts_feats + time_emb

    if args.use_ours_transformer_enc:
        seq_latents = encoder(
            seq_latents.permute(1, 0, 2),
            set_attn_to_none=args.set_attn_to_none,
        )
        seq_latents = seq_latents.permute(1, 0, 2)
    else:
        # Fix: the original passed the undefined name `basejtsrel_seq_latents`
        # here in the prepend-token case, which raised NameError.
        seq_latents = encoder(seq_latents)

    if prepend_time_token:
        # Drop the leading time-embedding entry added above.
        seq_latents = seq_latents[1:]
    return seq_latents
def forward(self, x, timesteps): | |
""" | |
x: [batch_size, njoints, nfeats, max_frames], denoted x_t in the paper | |
timesteps: [batch_size] (int) | |
""" | |
# joint_sequence_input_process; joint_sequence_pos_encoder; joint_sequence_seqTransEncoder; joint_sequence_seqTransDecoder; joint_sequence_embed_timestep; joint_sequence_output_process # | |
# bsz, nframes, nnj = x['pert_rhand_joints'].shape[:3] | |
# pert_rhand_joints = x['pert_rhand_joints'] # bsz x ws x nnj x 3 # bsz x nnj x 3 x ws | |
bsz, nframes, nnj = x['rhand_joints'].shape[:3] | |
pert_rhand_joints = x['rhand_joints'] # bsz x ws x nnj x 3 # bsz x nnj x 3 x ws | |
base_pts = x['base_pts'] ### bsz x nnb x 3 ### | |
base_normals = x['base_normals'] ### bsz x nnb x 3 ### --> base normals ### | |
# base_normals # ## | |
rt_dict = {} | |
## # # logvar_seqTransEncoder_e, logvar_seqTransEncoder, joint_sequence_logvar_seqTransEncoder # # | |
if self.diff_basejtse: | |
### Embed physicss quantities ### | |
# e_disp_rel_to_base_along_normals: bsz x (ws - 1) x nnj x nnb # | |
# e_disp_rel_to_baes_vt_normals: bsz x (ws - 1) x nnj x nnb # | |
# e_disp_rel_to_base_along_normals = x['e_disp_rel_to_base_along_normals'] | |
# e_disp_rel_to_baes_vt_normals = x['e_disp_rel_to_baes_vt_normals'] | |
e_disp_rel_to_base_along_normals = x['pert_e_disp_rel_to_base_along_normals'] | |
e_disp_rel_to_baes_vt_normals = x['pert_e_disp_rel_to_base_vt_normals'] | |
nnb = base_pts.size(1) | |
disp_ws = e_disp_rel_to_base_along_normals.size(1) ### --> base normals ### | |
base_pts_disp_exp = base_pts.unsqueeze(1).unsqueeze(1).repeat(1, disp_ws, nnj, 1, 1).contiguous() | |
base_normals_disp_exp = base_normals.unsqueeze(1).unsqueeze(1).repeat(1, disp_ws, nnj, 1, 1).contiguous() | |
# bsz x (ws - 1) x nnj x nnb x (3 + 3 + 1 + 1) | |
base_pts_normals_e_in_feats = torch.cat( # along normals; # vt normals # | |
[base_pts_disp_exp, base_normals_disp_exp, e_disp_rel_to_base_along_normals.unsqueeze(-1), e_disp_rel_to_baes_vt_normals.unsqueeze(-1)], dim=-1 | |
) | |
base_pts_normals_e_in_feats = base_pts_normals_e_in_feats.permute(0, 1, 3, 2, 4).contiguous() | |
# bsz x (ws - 1) x nnb x (nnj x (xxx feats_dim)) | |
base_pts_normals_e_in_feats = base_pts_normals_e_in_feats.view(bsz, disp_ws, nnb, -1).contiguous() | |
## input process ## | |
base_jts_e_feats = self.input_process_e(base_pts_normals_e_in_feats) | |
base_jts_e_feats = self.sequence_pos_encoder_e(base_jts_e_feats) | |
## seq transformation for e ## # | |
base_jts_e_feats_mean = self.seqTransEncoder_e(base_jts_e_feats) ## mean, mdm_ours ## | |
# print(f"base_jts_e_feats: {base_jts_e_feats.size()}") | |
### Embed physicss quantities ### | |
#### base_jts_e_feats, base_jts_e_feats_mean #### | |
# ## us basejtsefeats for denoising directly ## | |
base_jts_e_feats = base_jts_e_feats_mean | |
if not self.args.without_dec_pos_emb: ## use positional encoding ## | |
# rel_base_pts_outputs = self.sequence_pos_denoising_encoder(rel_base_pts_outputs) | |
base_jts_e_feats = self.sequence_pos_denoising_encoder_e(base_jts_e_feats) | |
if self.args.deep_fuse_timeemb: | |
## denoising process ### | |
## GET joints seq time embeddings ### ### embed time stamps ### | |
base_jts_e_time_emb = self.embed_timestep_e(timesteps) | |
base_jts_e_time_emb = base_jts_e_time_emb.repeat(base_jts_e_feats.size(0), 1, 1).contiguous() | |
base_jts_e_feats = base_jts_e_feats + base_jts_e_time_emb | |
if self.args.use_ours_transformer_enc: ## transformer encoder ## | |
base_jts_e_feats = self.basejtse_denoising_seqTransEncoder(base_jts_e_feats.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) | |
base_jts_e_feats = base_jts_e_feats.permute(1, 0, 2) | |
else: | |
base_jts_e_feats = self.basejtse_denoising_seqTransEncoder(base_jts_e_feats) | |
else: | |
#### Decode e_along_normals and e_vt_normals #### | |
base_jts_e_time_emb = self.embed_timestep_e(timesteps) | |
base_jts_e_feats = torch.cat( | |
[base_jts_e_time_emb, base_jts_e_feats], dim=0 | |
) | |
if self.args.use_ours_transformer_enc: ## transformer encoder ## | |
base_jts_e_feats = self.basejtse_denoising_seqTransEncoder(base_jts_e_feats.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) | |
base_jts_e_feats = base_jts_e_feats.permute(1, 0, 2)[1:] | |
else: | |
base_jts_e_feats = self.basejtse_denoising_seqTransEncoder(base_jts_e_feats)[1:] | |
# ### sequence latents ### | |
# if self.args.train_enc: # trian enc for seq latents ### | |
# base_jts_e_feats = input_latent_feats["base_jts_e_feats_enc"] | |
# base_jts_e_feats = self.sequence_pos_denoising_encoder_e(base_jts_e_feats) | |
#### Decode e_along_normals and e_vt_normals #### | |
### bae jts e embeddings feats ### | |
# base_jts_e_time_emb = self.embed_timestep_e(timesteps) | |
# base_jts_e_feats = torch.cat( | |
# [base_jts_e_time_emb, base_jts_e_feats], dim=0 | |
# ) | |
# base_jts_e_feats = self.basejtse_denoising_seqTransEncoder(base_jts_e_feats)[1:] | |
##### output_process_e -> output energies ##### | |
base_jts_e_output = self.output_process_e(base_jts_e_feats, x) | |
# dec_e_along_normals: bsz x (ws - 1) x nnb x nnj # | |
dec_e_along_normals = base_jts_e_output['dec_e_along_normals'] | |
dec_e_vt_normals = base_jts_e_output['dec_e_vt_normals'] | |
dec_e_along_normals = dec_e_along_normals.contiguous().permute(0, 1, 3, 2).contiguous() # bsz x (ws - 1) x nnj x nnb | |
dec_e_vt_normals = dec_e_vt_normals.contiguous().permute(0, 1, 3, 2).contiguous() # bsz x (ws - 1) x nnj x nnb | |
#### Decode e_along_normals and e_vt_normals #### | |
diff_basejtse_dict = { | |
'dec_e_along_normals': dec_e_along_normals, | |
'dec_e_vt_normals': dec_e_vt_normals, | |
'base_jts_e_feats': base_jts_e_feats, | |
} | |
# rt_dict['base_jts_e_feats'] = base_jts_e_feats | |
# rt_dict['base_jts_e_feats_mean'] = base_jts_e_feats_mean | |
# rt_dict['base_jts_e_feats_logvar'] = base_jts_e_feats_logvar # log_var # | |
else: | |
diff_basejtse_dict = {} | |
if self.diff_jts: | |
# base_pts_normal | |
### InputProcess ### | |
pert_rhand_joints_trans = pert_rhand_joints.permute(0, 2, 3, 1).contiguous() # bsz x nnj x 3 x ws # | |
rhand_joints_feats = self.joint_sequence_input_process(pert_rhand_joints_trans) # [seqlen, bs, d] | |
### InputProcessObjBase ### | |
# rhand_joints_feats = self.joint_sequence_input_process(pert_rhand_joints) | |
### === Encode input joint sequences === ### | |
# bs, njoints, nfeats, nframes = x.shape | |
# rhand_joints_emb = self.joint_sequence_embed_timestep(timesteps) # [1, bs, d] | |
# if self.arch == 'trans_enc': | |
xseq = rhand_joints_feats # [seqlen+1, bs, d] | |
xseq = self.joint_sequence_pos_encoder(xseq) # [seqlen+1, bs, d] | |
joint_seq_output_mean = self.joint_sequence_seqTransEncoder(xseq) # [1:] # , src_key_padding_mask=~maskseq) # [seqlen, bs, d] | |
### calculate logvar, mean, and feats ### | |
joint_seq_output_logvar = self.joint_sequence_logvar_seqTransEncoder(xseq) | |
# # logvar_seqTransEncoder_e, logvar_seqTransEncoder, joint_sequence_logvar_seqTransEncoder # # | |
joint_seq_output_var = torch.exp(joint_seq_output_logvar) # seq x bs x d --> encodeing and decoding | |
## base_jts_e_feats: seqlen x bs x d --> val latents ## | |
joint_seq_output = self.reparameterization(joint_seq_output_mean, joint_seq_output_var) | |
rt_dict['joint_seq_output'] = joint_seq_output | |
# rt_dict['joint_seq_output'] = joint_seq_output_mean | |
rt_dict['joint_seq_output_mean'] = joint_seq_output_mean | |
rt_dict['joint_seq_output_logvar'] = joint_seq_output_logvar | |
if self.args.diff_realbasejtsrel_to_joints: # nframes x nnbase x nnjts x (base pts + base normals + 3) 2) point feature for each point; point feature for; condition on the noisy input for the denoised information | |
# real_basejtsrel_to_joints_input_process, real_basejtsrel_to_joints_sequence_pos_encoder, real_basejtsrel_to_joints_seqTransEncoder | |
# real_basejtsrel_to_joints_embed_timestep, real_basejtsrel_to_joints_sequence_pos_denoising_encoder, real_basejtsrel_to_joints_denoising_seqTransEncoder, real_basejtsrel_to_joints_output_process | |
bsz, nf, nnj, nnb = x['pert_rel_base_pts_to_joints_for_jts_pred'].size()[:4] | |
normed_base_pts = x['normed_base_pts'] | |
base_normals = x['base_normals'] | |
pert_rel_base_pts_to_rhand_joints = x['pert_rel_base_pts_to_joints_for_jts_pred'] # bsz x nf x nnj x nnb x 3 | |
## use_abs_jts_pos --> obj jts pos for the encodingj | |
if self.args.use_abs_jts_pos: ## bsz x nf x nnj x nnb x 3 ## ---> abs jts pos ## | |
pert_rel_base_pts_to_rhand_joints = pert_rel_base_pts_to_rhand_joints + normed_base_pts.unsqueeze(1).unsqueeze(1) | |
# use_abs_jts_for_encoding, real_basejtsrel_to_joints_input_process | |
if self.args.use_abs_jts_for_encoding: | |
if not self.args.use_abs_jts_pos: | |
pert_rel_base_pts_to_rhand_joints = pert_rel_base_pts_to_rhand_joints + normed_base_pts.unsqueeze(1).unsqueeze(1) # pert_rel_base_pts_to_rhand_joints: bsz x nf x nnj x nnb x 3 | |
abs_jts = pert_rel_base_pts_to_rhand_joints[..., 0, :] | |
abs_jts = abs_jts.permute(0, 2, 3, 1).contiguous() | |
obj_base_encoded_feats = self.real_basejtsrel_to_joints_input_process(abs_jts) | |
elif self.args.use_abs_jts_for_encoding_obj_base: | |
if not self.args.use_abs_jts_pos: | |
pert_rel_base_pts_to_rhand_joints = pert_rel_base_pts_to_rhand_joints + normed_base_pts.unsqueeze(1).unsqueeze(1) # pert_rel_base_pts_to_rhand_joints: bsz x nf x nnj x nnb x 3 | |
pert_rel_base_pts_to_rhand_joints = pert_rel_base_pts_to_rhand_joints[:, :, :, 0:1, :] | |
# obj_base_in_feats = torch.cat( # bsz x nf x nnj x nnb x (3 + 3 + 3) | |
# [normed_base_pts.unsqueeze(1).unsqueeze(1).repeat(1, nf, nnj, 1, 1)[:, :, :, 0:1, :], base_normals.unsqueeze(1).unsqueeze(1).repeat(1, nf, nnj, 1, 1)[:, :, :, 0:1, :], pert_rel_base_pts_to_rhand_joints], dim=-1 | |
# ) | |
obj_base_in_feats = pert_rel_base_pts_to_rhand_joints | |
# --> tnrasform the input feature dim to 21 * 3 here for encoding # | |
obj_base_in_feats = obj_base_in_feats.transpose(-2, -3).contiguous().view(bsz, nf, 1, -1).contiguous() # | |
obj_base_encoded_feats = self.real_basejtsrel_to_joints_input_process(obj_base_in_feats) # nf x bsz x feat_dim # | |
else: | |
if self.args.use_objbase_v2: | |
# bsz x nf x nnj x nnb x 3 | |
pert_rel_base_pts_to_rhand_joints_exp = pert_rel_base_pts_to_rhand_joints.transpose(-2, -3).contiguous().view(bsz, nf, nnb, -1).contiguous() | |
obj_base_in_feats = torch.cat( | |
[normed_base_pts.unsqueeze(1).repeat(1, nf, 1, 1), base_normals.unsqueeze(1).repeat(1, nf, 1, 1), pert_rel_base_pts_to_rhand_joints_exp], dim=-1 | |
) | |
else: | |
obj_base_in_feats = torch.cat( # bsz x nf x nnj x nnb x (3 + 3 + 3) | |
[normed_base_pts.unsqueeze(1).unsqueeze(1).repeat(1, nf, nnj, 1, 1), base_normals.unsqueeze(1).unsqueeze(1).repeat(1, nf, nnj, 1, 1), pert_rel_base_pts_to_rhand_joints], dim=-1 | |
) | |
# print(f"obj_base_in_feats: {obj_base_in_feats.size()}, bsz: {bsz}, nf: {nf}, nnj: {nnj}, nnb: {nnb}") | |
obj_base_in_feats = obj_base_in_feats.transpose(-2, -3).contiguous().view(bsz, nf, nnb, -1).contiguous() # | |
obj_base_encoded_feats = self.real_basejtsrel_to_joints_input_process(obj_base_in_feats) # nf x bsz x feat_dim # | |
# ### real_basejtsrel_to_joints_input_process --> real_basejtsrel_to_joints_input_process --> for the joints and input process ### # obj_base_encoded_feats | |
# obj_base_encoded_feats | |
obj_base_pts_feats_pos_embedding = self.real_basejtsrel_to_joints_sequence_pos_encoder(obj_base_encoded_feats) | |
obj_base_pts_feats = self.real_basejtsrel_to_joints_seqTransEncoder(obj_base_pts_feats_pos_embedding) | |
if self.args.use_sigmoid: | |
obj_base_pts_feats = (torch.sigmoid(obj_base_pts_feats) - 0.5) * 2. | |
if self.args.deep_fuse_timeemb: | |
## denoising process ### | |
## GET joints seq time embeddings ### ### embed time stamps ### | |
real_basejtsrel_time_emb = self.real_basejtsrel_to_joints_embed_timestep(timesteps) | |
real_basejtsrel_time_emb = real_basejtsrel_time_emb.repeat(obj_base_pts_feats.size(0), 1, 1).contiguous() | |
real_basejtsrel_seq_latents = obj_base_pts_feats + real_basejtsrel_time_emb | |
if self.args.use_ours_transformer_enc: | |
real_basejtsrel_seq_latents = self.real_basejtsrel_to_joints_denoising_seqTransEncoder(real_basejtsrel_seq_latents.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) | |
real_basejtsrel_seq_latents = real_basejtsrel_seq_latents.permute(1, 0, 2) | |
else: # seq des | |
real_basejtsrel_seq_latents = self.real_basejtsrel_to_joints_denoising_seqTransEncoder(real_basejtsrel_seq_latents) | |
else: | |
## denoising process ### | |
## GET joints seq time embeddings ### ### embed time stamps ### | |
real_basejtsrel_time_emb = self.real_basejtsrel_to_joints_embed_timestep(timesteps) | |
real_basejtsrel_seq_latents = torch.cat( | |
[real_basejtsrel_time_emb, obj_base_pts_feats], dim=0 | |
) | |
if self.args.use_ours_transformer_enc: ## mdm ours ## | |
real_basejtsrel_seq_latents = self.real_basejtsrel_to_joints_denoising_seqTransEncoder(real_basejtsrel_seq_latents.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) | |
real_basejtsrel_seq_latents = real_basejtsrel_seq_latents.permute(1, 0, 2)[1:] | |
else: | |
real_basejtsrel_seq_latents = self.real_basejtsrel_to_joints_denoising_seqTransEncoder(basejtsrel_seq_latents)[1:] | |
joints_offset_output = self.real_basejtsrel_to_joints_output_process(real_basejtsrel_seq_latents) | |
joints_offset_output = joints_offset_output.permute(0, 3, 1, 2) | |
if self.args.diff_basejtsrel: | |
diff_basejtsrel_to_joints_dict = { | |
'joints_offset_output_from_rel': joints_offset_output | |
} | |
else: | |
diff_basejtsrel_to_joints_dict = { | |
'joints_offset_output': joints_offset_output | |
} | |
else: | |
diff_basejtsrel_to_joints_dict = {} | |
if self.diff_realbasejtsrel: # real_dec_basejtsrel | |
# real_basejtsrel_input_process, real_basejtsrel_sequence_pos_encoder, real_basejtsrel_seqTransEncoder, real_basejtsrel_embed_timestep, real_basejtsrel_sequence_pos_denoising_encoder, real_basejtsrel_denoising_seqTransEncoder | |
bsz, nf, nnj, nnb = x['pert_rel_base_pts_to_rhand_joints'].size()[:4] | |
normed_base_pts = x['normed_base_pts'] | |
base_normals = x['base_normals'] | |
pert_rel_base_pts_to_rhand_joints = x['pert_rel_base_pts_to_rhand_joints'] # bsz x nf x nnj x nnb x 3 | |
if self.args.use_objbase_v2: | |
# bsz x nf x nnj x nnb x 3 | |
pert_rel_base_pts_to_rhand_joints_exp = pert_rel_base_pts_to_rhand_joints.transpose(-2, -3).contiguous().view(bsz, nf, nnb, -1).contiguous() | |
obj_base_in_feats = torch.cat( | |
[normed_base_pts.unsqueeze(1).repeat(1, nf, 1, 1), base_normals.unsqueeze(1).repeat(1, nf, 1, 1), pert_rel_base_pts_to_rhand_joints_exp], dim=-1 | |
) | |
elif self.args.use_objbase_v4: | |
# use_objbase_v4: # use_objbase_out_v4 | |
exp_normed_base_pts = normed_base_pts.unsqueeze(1).unsqueeze(2).repeat(1, nf, nnj, 1, 1).contiguous() | |
exp_base_normals = base_normals.unsqueeze(1).unsqueeze(2).repeat(1, nf, nnj, 1, 1).contiguous() | |
obj_base_in_feats = torch.cat( | |
[pert_rel_base_pts_to_rhand_joints, exp_normed_base_pts, exp_base_normals], dim=-1 # bsz x nf x nnj x nnb x (3 + 3 + 3) # -> exp_base_normals | |
) | |
obj_base_in_feats = obj_base_in_feats.view(bsz, nf, nnj, -1).contiguous() | |
elif self.args.use_objbase_v5: # use_objbase_v5, use_objbase_out_v5 | |
pert_rel_base_pts_to_rhand_joints_exp = pert_rel_base_pts_to_rhand_joints.transpose(-2, -3).contiguous().view(bsz, nf, nnb, -1).contiguous() | |
if self.args.v5_in_not_base: | |
obj_base_in_feats = torch.cat( | |
[ pert_rel_base_pts_to_rhand_joints_exp], dim=-1 | |
) | |
elif self.args.v5_in_not_base_pos: | |
obj_base_in_feats = torch.cat( | |
[base_normals.unsqueeze(1).repeat(1, nf, 1, 1), pert_rel_base_pts_to_rhand_joints_exp], dim=-1 | |
) | |
else: | |
obj_base_in_feats = torch.cat( | |
[normed_base_pts.unsqueeze(1).repeat(1, nf, 1, 1), base_normals.unsqueeze(1).repeat(1, nf, 1, 1), pert_rel_base_pts_to_rhand_joints_exp], dim=-1 | |
) | |
elif self.args.use_objbase_v6 or self.args.use_objbase_v7: | |
pert_rel_base_pts_to_rhand_joints_exp = pert_rel_base_pts_to_rhand_joints.transpose(-2, -3).contiguous().view(bsz, nf, nnb, -1).contiguous() | |
obj_base_in_feats = torch.cat( | |
[normed_base_pts.unsqueeze(1).repeat(1, nf, 1, 1), base_normals.unsqueeze(1).repeat(1, nf, 1, 1), pert_rel_base_pts_to_rhand_joints_exp], dim=-1 | |
) | |
else: | |
obj_base_in_feats = torch.cat( # bsz x nf x nnj x nnb x (3 + 3 + 3) | |
[normed_base_pts.unsqueeze(1).unsqueeze(1).repeat(1, nf, nnj, 1, 1), base_normals.unsqueeze(1).unsqueeze(1).repeat(1, nf, nnj, 1, 1), pert_rel_base_pts_to_rhand_joints], dim=-1 | |
) | |
# print(f"obj_base_in_feats: {obj_base_in_feats.size()}, bsz: {bsz}, nf: {nf}, nnj: {nnj}, nnb: {nnb}") | |
obj_base_in_feats = obj_base_in_feats.transpose(-2, -3).contiguous().view(bsz, nf, nnb, -1).contiguous() # | |
# obj_base_in_feats = torch.cat( # bsz x nf x nnj x nnb x (3 + 3 + 3) | |
# [normed_base_pts.unsqueeze(1).unsqueeze(1).repeat(1, nf, nnj, 1, 1), base_normals.unsqueeze(1).unsqueeze(1).repeat(1, nf, nnj, 1, 1), pert_rel_base_pts_to_rhand_joints], dim=-1 | |
# ) | |
# print(f"obj_base_in_feats: {obj_base_in_feats.size()}, bsz: {bsz}, nf: {nf}, nnj: {nnj}, nnb: {nnb}") | |
# obj_base_in_feats = obj_base_in_feats.transpose(-2, -3).contiguous().view(bsz, nf, nnb, -1).contiguous() # | |
if self.args.use_objbase_v6: | |
normed_base_pts_exp = normed_base_pts.unsqueeze(1).repeat(1, nf, 1, 1) # and repeat for the base pts # | |
obj_base_encoded_feats = self.real_basejtsrel_input_process(obj_base_in_feats, normed_base_pts_exp) | |
else: | |
obj_base_encoded_feats = self.real_basejtsrel_input_process(obj_base_in_feats) # nf x bsz x feat_dim # nf x bsz x nnbasepts x feats_dim # | |
# obj_base_encoded_feats | |
obj_base_pts_feats_pos_embedding = self.real_basejtsrel_sequence_pos_encoder(obj_base_encoded_feats) | |
obj_base_pts_feats = self.real_basejtsrel_seqTransEncoder(obj_base_pts_feats_pos_embedding) | |
if self.args.use_sigmoid: | |
obj_base_pts_feats = (torch.sigmoid(obj_base_pts_feats) - 0.5) * 2. | |
if self.args.train_enc: | |
# basejtsrel_seq | |
# bsz, nframes, nnb, nnj, 3 --> | |
# | |
real_dec_basejtsrel = self.real_basejtsrel_output_process( | |
obj_base_pts_feats, x | |
) | |
real_dec_basejtsrel = real_dec_basejtsrel['dec_rel'] | |
if self.args.use_objbase_out_v3: | |
real_dec_basejtsrel = real_dec_basejtsrel | |
else: | |
real_dec_basejtsrel = real_dec_basejtsrel.permute(0, 1, 3, 2, 4).contiguous() # bsz x nf x nnj x nnb x 3 | |
else: | |
if self.args.deep_fuse_timeemb: | |
## denoising process ### | |
## GET joints seq time embeddings ### ### embed time stamps ### | |
# print(f"timesteps: {timesteps.size()}, obj_base_pts_feats: {obj_base_pts_feats.size()}") | |
if self.args.use_objbase_v5: | |
cur_timesteps = timesteps.unsqueeze(1).repeat(1, nnb).view(-1) | |
else: | |
cur_timesteps = timesteps | |
real_basejtsrel_time_emb = self.real_basejtsrel_embed_timestep(cur_timesteps) | |
real_basejtsrel_time_emb = real_basejtsrel_time_emb.repeat(obj_base_pts_feats.size(0), 1, 1).contiguous() | |
real_basejtsrel_seq_latents = obj_base_pts_feats + real_basejtsrel_time_emb | |
if self.args.use_ours_transformer_enc: | |
real_basejtsrel_seq_latents = self.real_basejtsrel_denoising_seqTransEncoder(real_basejtsrel_seq_latents.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) | |
real_basejtsrel_seq_latents = real_basejtsrel_seq_latents.permute(1, 0, 2) | |
else: # seq des | |
real_basejtsrel_seq_latents = self.real_basejtsrel_denoising_seqTransEncoder(real_basejtsrel_seq_latents) | |
else: | |
## denoising process ### | |
## GET joints seq time embeddings ### ### embed time stamps ### | |
real_basejtsrel_time_emb = self.real_basejtsrel_embed_timestep(timesteps) | |
real_basejtsrel_seq_latents = torch.cat( | |
[real_basejtsrel_time_emb, obj_base_pts_feats], dim=0 | |
) | |
if self.args.use_ours_transformer_enc: ## mdm ours ## | |
real_basejtsrel_seq_latents = self.real_basejtsrel_denoising_seqTransEncoder(real_basejtsrel_seq_latents.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) | |
real_basejtsrel_seq_latents = real_basejtsrel_seq_latents.permute(1, 0, 2)[1:] | |
else: | |
real_basejtsrel_seq_latents = self.real_basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents)[1:] | |
# basejtsrel_seq | |
# bsz, nframes, nnb, nnj, 3 --> | |
# | |
if self.args.use_jts_pert_realbasejtsrel: | |
joints_offset_output = self.real_basejtsrel_output_process(real_basejtsrel_seq_latents) | |
joints_offset_output = joints_offset_output.permute(0, 3, 1, 2) # bsz x nf x nnj x 3 | |
real_dec_basejtsrel = joints_offset_output.unsqueeze(-2).repeat(1, 1, 1, nnb, 1) | |
# real_dec_basejtsrel = joints_offset_output | |
else: | |
real_dec_basejtsrel = self.real_basejtsrel_output_process( | |
real_basejtsrel_seq_latents, x | |
) | |
# real_dec_basejtsrel -> decoded realtive positions # | |
real_dec_basejtsrel = real_dec_basejtsrel['dec_rel'] | |
if self.args.use_objbase_out_v3 or self.args.use_objbase_out_v4 or self.args.use_objbase_out_v5: | |
real_dec_basejtsrel = real_dec_basejtsrel | |
else: | |
real_dec_basejtsrel = real_dec_basejtsrel.permute(0, 1, 3, 2, 4).contiguous() # bsz x nf x nnj x nnb x 3 | |
diff_realbasejtsrel_out_dict = { | |
'real_dec_basejtsrel': real_dec_basejtsrel, | |
'obj_base_pts_feats': obj_base_pts_feats, | |
} | |
else: | |
diff_realbasejtsrel_out_dict = {} | |
# relative joints encoder; obj pos encoder; obj pos encoder; # penetrations, depth, --- how to use depth for guidance -> and also penetrations # object penetrations # | |
if self.diff_basejtsrel: | |
# joints_offset_sequence --> x['pert_joints_offset_sequence'] | |
joints_offset_sequence = x['pert_joints_offset_sequence'] # bsz x nf x nnj x 3 | |
joints_offset_sequence = joints_offset_sequence.permute(0, 2, 3, 1).contiguous() | |
joints_offset_feats = self.joints_offset_input_process(joints_offset_sequence) # nf x bsz x dim | |
# rel_base_pts_feats = self.input_process(basejtsrel_enc_in_feats) | |
# sequence_pos_encoder | |
rel_base_pts_feats_pos_embedding = self.sequence_pos_encoder(joints_offset_feats) | |
# print(f"joints_offset_feats: {joints_offset_feats.size()}, rel_base_pts_feats_pos_embedding: {rel_base_pts_feats_pos_embedding.size()}") | |
# outputs rel base jts encoded latents ## | |
# seqTransEncoder, logvar_seqTransEncoder | |
# rel_base_pts_outputs_mean = self.basejtsrel_denoising_seqTransEncoder(rel_base_pts_feats_pos_embedding) | |
# ### calculate logvar, mean, and feats ### | |
# rel_base_pts_outputs_logvar = self.joint_sequence_logvar_seqTransEncoder(rel_base_pts_outputs) | |
if self.args.not_diff_avgjts: # not use diff avgjts ## | |
rel_base_pts_feats = rel_base_pts_feats_pos_embedding | |
# seqTransEncoder, logvar_seqTransEncoder # | |
rel_base_pts_outputs_mean = self.seqTransEncoder(rel_base_pts_feats) | |
# print(f"rel_base_pts_outputs_mean 1: {rel_base_pts_outputs_mean.size()}") | |
if not self.args.without_dec_pos_emb: # without dec pos embedding | |
# avg_jts_inputs = rel_base_pts_outputs_mean[0:1, ...] | |
other_rel_base_pts_outputs = rel_base_pts_outputs_mean # | |
other_rel_base_pts_outputs = self.sequence_pos_denoising_encoder(other_rel_base_pts_outputs) | |
rel_base_pts_outputs_mean = other_rel_base_pts_outputs | |
if self.args.deep_fuse_timeemb: | |
## denoising process ### | |
## GET joints seq time embeddings ### ### embed time stamps ### | |
basejtsrel_time_emb = self.basejtsrel_denoising_embed_timestep(timesteps) | |
basejtsrel_time_emb = basejtsrel_time_emb.repeat(rel_base_pts_outputs_mean.size(0), 1, 1).contiguous() | |
basejtsrel_seq_latents = rel_base_pts_outputs_mean + basejtsrel_time_emb ### time embeddings and relbaseptsoutputs | |
# print(f"basejtsrel_seq_latents: {basejtsrel_seq_latents.size()}, rel_base_pts_outputs_mean: {rel_base_pts_outputs_mean.size()}, basejtsrel_time_emb: {basejtsrel_time_emb.size()}") | |
if self.args.use_ours_transformer_enc: | |
basejtsrel_seq_latents = self.basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) | |
basejtsrel_seq_latents = basejtsrel_seq_latents.permute(1, 0, 2) | |
else: | |
basejtsrel_seq_latents = self.basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents) | |
else: | |
## denoising process ### | |
## GET joints seq time embeddings ### ### embed time stamps ### | |
basejtsrel_time_emb = self.basejtsrel_denoising_embed_timestep(timesteps) | |
basejtsrel_seq_latents = torch.cat( | |
[basejtsrel_time_emb, rel_base_pts_outputs_mean], dim=0 | |
) | |
if self.args.use_ours_transformer_enc: ## mdm ours ## | |
basejtsrel_seq_latents = self.basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) | |
basejtsrel_seq_latents = basejtsrel_seq_latents.permute(1, 0, 2)[1:] | |
else: | |
basejtsrel_seq_latents = self.basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents)[1:] | |
# basejtsrel_seq_latents_pred_feats | |
# avg_jts_seq_latents = basejtsrel_seq_latents[0:1, ...] | |
other_basejtsrel_seq_latents = basejtsrel_seq_latents # [1:, ...] | |
joints_offset_output = self.joint_offset_output_process(other_basejtsrel_seq_latents) | |
joints_offset_output = joints_offset_output.permute(0, 3, 1, 2) | |
# print(f"joints_offset_output in MDM: {joints_offset_output.size()}, joints_offset_sequence: {joints_offset_sequence.size()}, other_basejtsrel_seq_latents: {other_basejtsrel_seq_latents.size()}") | |
diff_basejtsrel_dict = { | |
'joints_offset_output': joints_offset_output | |
} | |
# rt_dict['joints_offset_output'] = joints_offset_output | |
else: | |
avg_joints_sequence = x['pert_avg_joints_sequence'] | |
avg_joints_sequence_trans = avg_joints_sequence.unsqueeze(-1) | |
avg_joints_feats = self.avg_joints_sequence_input_process(avg_joints_sequence_trans) ## 1 x bsz x dim ### | |
rel_base_pts_feats = torch.cat( # (seq_len + 1) x bsz x dim # | |
[avg_joints_feats, rel_base_pts_feats_pos_embedding], dim=0 ## jrel_base_pts_pos_embedding # | |
) | |
## joints embedding for mean statistics and logvar statistics ## | |
# seqTransEncoder, logvar_seqTransEncoder # | |
rel_base_pts_outputs_mean = self.seqTransEncoder(rel_base_pts_feats) | |
if not self.args.without_dec_pos_emb: # without dec pos embedding | |
avg_jts_inputs = rel_base_pts_outputs_mean[0:1, ...] | |
other_rel_base_pts_outputs = rel_base_pts_outputs_mean[1: , ...] | |
other_rel_base_pts_outputs = self.sequence_pos_denoising_encoder(other_rel_base_pts_outputs) | |
rel_base_pts_outputs_mean = torch.cat( | |
[avg_jts_inputs, other_rel_base_pts_outputs], dim=0 | |
) | |
if self.args.deep_fuse_timeemb: | |
## denoising process ### | |
## GET joints seq time embeddings ### ### embed time stamps ### | |
basejtsrel_time_emb = self.basejtsrel_denoising_embed_timestep(timesteps) | |
basejtsrel_time_emb = basejtsrel_time_emb.repeat(rel_base_pts_outputs_mean.size(0), 1, 1).contiguous() | |
basejtsrel_seq_latents = rel_base_pts_outputs_mean + basejtsrel_time_emb ### time embeddings and relbaseptsoutputs | |
if self.args.use_ours_transformer_enc: | |
basejtsrel_seq_latents = self.basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) | |
basejtsrel_seq_latents = basejtsrel_seq_latents.permute(1, 0, 2) | |
else: | |
basejtsrel_seq_latents = self.basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents) | |
else: | |
## denoising process ### | |
## GET joints seq time embeddings ### ### embed time stamps ### | |
basejtsrel_time_emb = self.basejtsrel_denoising_embed_timestep(timesteps) | |
basejtsrel_seq_latents = torch.cat( | |
[basejtsrel_time_emb, rel_base_pts_outputs_mean], dim=0 | |
) | |
if self.args.use_ours_transformer_enc: ## mdm ours ## | |
basejtsrel_seq_latents = self.basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) | |
basejtsrel_seq_latents = basejtsrel_seq_latents.permute(1, 0, 2)[1:] | |
else: | |
basejtsrel_seq_latents = self.basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents)[1:] | |
# basejtsrel_seq_latents_pred_feats | |
avg_jts_seq_latents = basejtsrel_seq_latents[0:1, ...] | |
other_basejtsrel_seq_latents = basejtsrel_seq_latents[1:, ...] | |
avg_jts_outputs = self.avg_joint_sequence_output_process(avg_jts_seq_latents) # bsz x njoints x nfeats x 1 | |
avg_jts_outputs = avg_jts_outputs.squeeze(-1) # bsz x nnjoints x 1 --> avg joints here # | |
# basejtsrel_seq_latents = self.basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents)[1:] | |
# basejtsrel_output = self.output_process(other_basejtsrel_seq_latents, x) | |
joints_offset_output = self.joint_offset_output_process(other_basejtsrel_seq_latents) | |
joints_offset_output = joints_offset_output.permute(0, 3, 1, 2) | |
rt_dict['joints_offset_output'] = joints_offset_output | |
rt_dict['avg_jts_outputs'] = avg_jts_outputs | |
else: | |
diff_basejtsrel_dict = {} | |
rt_dict = {} | |
rt_dict.update(diff_basejtsrel_dict) | |
rt_dict.update(diff_basejtse_dict) ### rt_dict and diff_basejtse | |
rt_dict.update(diff_basejtsrel_to_joints_dict) | |
rt_dict.update(diff_realbasejtsrel_out_dict) ### diff | |
return rt_dict | |
def _apply(self, fn, *args, **kwargs):
    """Apply ``fn`` to this module through ``nn.Module._apply``.

    ``nn.Module.to`` / ``cuda`` / ``cpu`` / ``float`` / ``half`` are all
    implemented as ``return self._apply(...)``, so this override has to
    hand the base-class result back; otherwise ``model.to(device)``
    evaluates to ``None`` instead of the module.

    Args:
        fn: Callable forwarded unchanged to ``nn.Module._apply``.
        *args: Extra positional arguments forwarded to the base method.
        **kwargs: Extra keyword arguments forwarded to the base method
            (newer PyTorch releases pass ``recurse=...`` from
            ``to_empty``).

    Returns:
        Whatever ``nn.Module._apply`` returns (the module itself).
    """
    # NOTE: forwarding ``fn`` to ``self.rot2xyz.smpl_model`` is intentionally
    # left disabled, as in the previous version of this method.
    return super()._apply(fn, *args, **kwargs)
def train(self, *args, **kwargs):
    """Switch training/evaluation mode through ``nn.Module.train``.

    ``nn.Module.eval`` is implemented as ``return self.train(False)``, so
    the base-class result must be returned here; otherwise both
    ``model.train()`` and ``model.eval()`` evaluate to ``None`` and
    chained calls such as ``model = Model(...).eval()`` break.

    Args:
        *args: Positional arguments forwarded to ``nn.Module.train``
            (typically the ``mode`` flag).
        **kwargs: Keyword arguments forwarded to ``nn.Module.train``.

    Returns:
        Whatever ``nn.Module.train`` returns (the module itself).
    """
    # NOTE: propagating the mode to ``self.rot2xyz.smpl_model`` is
    # intentionally left disabled, as in the previous version of this method.
    return super().train(*args, **kwargs)
### MDM 12 ###
class MDMV12(nn.Module): | |
def __init__(self, modeltype, njoints, nfeats, num_actions, translation, pose_rep, glob, glob_rot, | |
latent_dim=256, ff_size=1024, num_layers=8, num_heads=4, dropout=0.1, | |
ablation=None, activation="gelu", legacy=False, data_rep='rot6d', dataset='amass', clip_dim=512, | |
arch='trans_enc', emb_trans_dec=False, clip_version=None, **kargs): | |
super().__init__() | |
self.legacy = legacy | |
self.modeltype = modeltype | |
self.njoints = njoints | |
self.nfeats = nfeats | |
self.num_actions = num_actions | |
self.data_rep = data_rep | |
self.dataset = dataset | |
self.pose_rep = pose_rep | |
self.glob = glob | |
self.glob_rot = glob_rot | |
self.translation = translation | |
self.latent_dim = latent_dim | |
self.ff_size = ff_size | |
self.num_layers = num_layers | |
self.num_heads = num_heads | |
self.dropout = dropout | |
self.ablation = ablation | |
self.activation = activation | |
self.clip_dim = clip_dim | |
self.action_emb = kargs.get('action_emb', None) | |
self.input_feats = self.njoints * self.nfeats | |
self.normalize_output = kargs.get('normalize_encoder_output', False) | |
self.cond_mode = kargs.get('cond_mode', 'no_cond') | |
self.cond_mask_prob = kargs.get('cond_mask_prob', 0.) # | |
### GET args ### | |
self.args = kargs.get('args', None) | |
### GET the diff. suit ### | |
self.diff_jts = self.args.diff_jts | |
self.diff_basejtsrel = self.args.diff_basejtsrel | |
self.diff_basejtse = self.args.diff_basejtse | |
self.diff_realbasejtsrel = self.args.diff_realbasejtsrel | |
self.diff_realbasejtsrel_to_joints = self.args.diff_realbasejtsrel_to_joints | |
### GET the diff. suit ### | |
self.arch = arch | |
## ==== gru_emb_dim ==== ## # gru emb dim # | |
self.gru_emb_dim = self.latent_dim if self.arch == 'gru' else 0 | |
# | |
# joint_sequence_input_process; joint_sequence_pos_encoder; joint_sequence_seqTransEncoder; joint_sequence_seqTransDecoder; joint_sequence_embed_timestep; joint_sequence_output_process # | |
###### ======= Construct joint sequence encoder, communicator, and decoder ======== ####### | |
self.joints_feats_in_dim = 21 * 3 | |
self.data_rep = "xyz" | |
if self.diff_jts: | |
## Input process for joints ## | |
self.joint_sequence_input_process = InputProcess(self.data_rep, self.joints_feats_in_dim, self.latent_dim) | |
# # InputProcessObjBase(self, data_rep, input_feats, latent_dim) | |
# self.joint_sequence_input_process = InputProcessObjBase(self.data_rep, 3, self.latent_dim) | |
# self.input_process = InputProcess(self.data_rep, self.input_feats+self.gru_emb_dim, self.latent_dim) | |
self.joint_sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout) | |
self.emb_trans_dec = emb_trans_dec | |
# if self.arch == 'trans_enc': | |
print("TRANS_ENC init") ## transformer encoder layer ## UNet | |
seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, | |
nhead=self.num_heads, | |
dim_feedforward=self.ff_size, | |
dropout=self.dropout, | |
activation=self.activation) | |
## Joint sequence transformer encoder layer ## # sequence encoder # | |
self.joint_sequence_seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer, | |
num_layers=self.num_layers) | |
### logvar for the encoding laeyer and | |
# # logvar_seqTransEncoder_e, logvar_seqTransEncoder, joint_sequence_logvar_seqTransEncoder # # | |
seqTransEncoderLayer_logvar = nn.TransformerEncoderLayer(d_model=self.latent_dim, | |
nhead=self.num_heads, | |
dim_feedforward=self.ff_size, | |
dropout=self.dropout, | |
activation=self.activation) | |
## Joint sequence transformer encoder layer ## # sequence encoder # | |
self.joint_sequence_logvar_seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer_logvar, | |
num_layers=self.num_layers) | |
# elif self.arch == 'trans_dec': | |
# print("TRANS_DEC init") | |
# seqTransDecoderLayer = nn.TransformerDecoderLayer(d_model=self.latent_dim, | |
# nhead=self.num_heads, # num_heads | |
# dim_feedforward=self.ff_size, | |
# dropout=self.dropout, | |
# activation=activation) | |
# self.joint_sequence_seqTransDecoder = nn.TransformerDecoder(seqTransDecoderLayer, | |
# num_layers=self.num_layers) | |
# elif self.arch == 'gru': | |
# print("GRU init") | |
# self.joint_sequence_gru = nn.GRU(self.latent_dim, self.latent_dim, num_layers=self.num_layers, batch_first=True) | |
# else: | |
# raise ValueError('Please choose correct architecture [trans_enc, trans_dec, gru]') | |
### joint sequence embed timestep ## ## timestep | |
self.joint_sequence_embed_timestep = TimestepEmbedder(self.latent_dim, self.joint_sequence_pos_encoder) | |
# self.joint_sequence_output_process = OutputProcess(self.data_rep, self.latent_dim) | |
# (self, data_rep, input_feats, latent_dim, njoints, nfeats): | |
#### ====== joint sequence denoising block ====== #### | |
## seqTransEncoder ## | |
self.joint_sequence_denoising_embed_timestep = TimestepEmbedder(self.latent_dim, self.joint_sequence_pos_encoder) | |
self.joint_sequence_denoising_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout) | |
if self.args.use_ours_transformer_enc: | |
self.joint_sequence_denoising_seqTransEncoder = model_utils.TransformerEncoder( | |
hidden_size=self.latent_dim, | |
fc_size=self.ff_size, | |
num_heads=self.num_heads, | |
layer_norm=True, | |
num_layers=self.num_layers, | |
dropout_rate=0.2, | |
re_zero=True, | |
memory_efficient=False, | |
) | |
else: | |
seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, | |
nhead=self.num_heads, | |
dim_feedforward=self.ff_size, | |
dropout=self.dropout, | |
activation=self.activation) | |
self.joint_sequence_denoising_seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer, | |
num_layers=self.num_layers) | |
# seq_len x bsz x dim | |
if self.args.const_noise: | |
# 1) max pool latents over the sequence | |
# 2) transform the pooled latnets via the linear layer | |
self.glb_denoising_latents_trans_layer = nn.Sequential( | |
nn.Linear(self.latent_dim, self.latent_dim * 2), nn.ReLU(), | |
nn.Linear(self.latent_dim * 2, self.latent_dim) | |
) | |
# refinement for predicted joints # --> not in the paradigm of generation # | |
# seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, | |
# nhead=self.num_heads, | |
# dim_feedforward=self.ff_size, | |
# dropout=self.dropout, | |
# activation=self.activation) | |
# self.joint_sequence_denoising_seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer, | |
# num_layers=self.num_layers) | |
#### ====== joint sequence denoiisng block ====== #### | |
### Output process ### output proces for joint sequence ### # output proces --> datarep, joints feats in dim, latent dim ## | |
###### joints_feats_in_dim ###### | |
self.joint_sequence_output_process = OutputProcess(self.data_rep, self.joints_feats_in_dim, self.latent_dim, 21, 3) | |
# OutputProcessCond | |
# self.joint_sequence_output_process = OutputProcessCond(self.data_rep, self.joints_feats_in_dim, self.latent_dim, 21, 3) | |
###### ======= Construct joint sequence encoder, communicator, and decoder ======== ####### | |
# real_basejtsrel_to_joints_embed_timestep, real_basejtsrel_to_joints_sequence_pos_denoising_encoder, real_basejtsrel_to_joints_denoising_seqTransEncoder, real_basejtsrel_to_joints_output_process | |
if self.diff_realbasejtsrel_to_joints: # feature for each joint point? --> for the denoising purpose # | |
# real_basejtsrel_input_process, real_basejtsrel_sequence_pos_encoder, real_basejtsrel_seqTransEncoder, real_basejtsrel_embed_timestep, real_basejtsrel_sequence_pos_denoising_encoder, real_basejtsrel_denoising_seqTransEncoder | |
layernorm = True | |
self.rel_input_feats = 21 * (3 + 3 + 3) # base pts, normals, the relative positions | |
if self.args.use_abs_jts_for_encoding_obj_base: | |
self.rel_input_feats = 21 * (3) | |
# layernorm = False | |
self.real_basejtsrel_to_joints_input_process = InputProcessObjBaseV2(self.data_rep, self.rel_input_feats, self.latent_dim, layernorm=layernorm) | |
# self.real_basejtsrel_to_joints_input_process = InputProcessObjBase(self.data_rep, self.rel_input_feats, self.latent_dim, layernorm=layernorm) | |
# elif self.args.use | |
else: | |
if self.args.use_objbase_v2: | |
self.rel_input_feats = 3 + 3 + (21 * 3) | |
self.real_basejtsrel_to_joints_input_process = InputProcessObjBaseV2(self.data_rep, self.rel_input_feats, self.latent_dim, layernorm=layernorm) | |
elif self.args.use_objbase_v3: | |
self.rel_input_feats = 3 + 3 + (21 * 3) | |
self.real_basejtsrel_to_joints_input_process = InputProcessObjBaseV3(self.data_rep, self.rel_input_feats, self.latent_dim, layernorm=layernorm) | |
else: | |
self.real_basejtsrel_to_joints_input_process = InputProcessObjBase(self.data_rep, self.rel_input_feats, self.latent_dim, layernorm=layernorm) | |
if self.args.use_abs_jts_for_encoding: # use_abs_jts_for_encoding, real_basejtsrel_to_joints_input_process | |
self.real_basejtsrel_to_joints_input_process = InputProcess(self.data_rep, self.joints_feats_in_dim, self.latent_dim) | |
self.real_basejtsrel_to_joints_sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout) | |
### Encoding layer ### # InputProcessObjBaseV2 | |
real_basejtsrel_to_joints_seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, # latent dim # nn_heads # ff_size # dropout # | |
nhead=self.num_heads, | |
dim_feedforward=self.ff_size, | |
dropout=self.dropout, | |
activation=self.activation) | |
self.real_basejtsrel_to_joints_seqTransEncoder = nn.TransformerEncoder(real_basejtsrel_to_joints_seqTransEncoderLayer, # basejtsrel_seqTrans | |
num_layers=self.num_layers) | |
### timesteps embedding layer ### | |
self.real_basejtsrel_to_joints_embed_timestep = TimestepEmbedder(self.latent_dim, self.real_basejtsrel_to_joints_sequence_pos_encoder) | |
self.real_basejtsrel_to_joints_sequence_pos_denoising_encoder = PositionalEncoding(self.latent_dim, self.dropout) | |
if self.args.use_ours_transformer_enc: # our transformer encoder # | |
self.real_basejtsrel_to_joints_denoising_seqTransEncoder = model_utils.TransformerEncoder( | |
hidden_size=self.latent_dim, | |
fc_size=self.ff_size, | |
num_heads=self.num_heads, | |
layer_norm=True, | |
num_layers=self.num_layers, | |
dropout_rate=0.2, | |
re_zero=True, | |
memory_efficient=False, | |
) | |
else: | |
real_basejtsrel_to_joints_denoising_seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, | |
nhead=self.num_heads, | |
dim_feedforward=self.ff_size, | |
dropout=self.dropout, | |
activation=self.activation) | |
self.real_basejtsrel_to_joints_denoising_seqTransEncoder = nn.TransformerEncoder(real_basejtsrel_to_joints_denoising_seqTransEncoderLayer, | |
num_layers=self.num_layers) | |
# self.real_basejtsrel_output_process = OutputProcessObjBaseRawV2(self.data_rep, self.latent_dim, not_cond_base=self.args.not_cond_base) | |
self.real_basejtsrel_to_joints_output_process = OutputProcess(self.data_rep, self.joints_feats_in_dim, self.latent_dim, 21, 3) | |
# OutputProcessCond | |
if self.diff_realbasejtsrel: | |
# real_basejtsrel_input_process, real_basejtsrel_sequence_pos_encoder, real_basejtsrel_seqTransEncoder, real_basejtsrel_embed_timestep, real_basejtsrel_sequence_pos_denoising_encoder, real_basejtsrel_denoising_seqTransEncoder | |
self.rel_input_feats = 21 * (3 + 3 + 3) # base pts, normals, the relative positions | |
# self.real_basejtsrel_input_process = InputProcessObjBase(self.data_rep, self.rel_input_feats, self.latent_dim) | |
layernorm = True | |
if self.args.use_objbase_v2: | |
self.rel_input_feats = 3 + 3 + (21 * 3) | |
self.real_basejtsrel_input_process = InputProcessObjBaseV2(self.data_rep, self.rel_input_feats, self.latent_dim, layernorm=layernorm, glb_feats_trans=True) | |
elif self.args.use_objbase_v4: # use_objbase_out_v4 | |
self.rel_input_feats = (self.args.nn_base_pts * (3 + 3 + 3)) # current joint positions # how to keep the dimension | |
self.real_basejtsrel_input_process = InputProcessObjBaseV4(self.data_rep, self.rel_input_feats, self.latent_dim, layernorm=layernorm) | |
elif self.args.use_objbase_v5: # use_objbase_v5, | |
if self.args.v5_in_not_base: | |
self.rel_input_feats = (21 * 3) | |
elif self.args.v5_in_not_base_pos: | |
self.rel_input_feats = 3 + (21 * 3) | |
else: | |
self.rel_input_feats = 3 + 3 + (21 * 3) | |
self.real_basejtsrel_input_process = InputProcessObjBaseV5(self.data_rep, self.rel_input_feats, self.latent_dim, layernorm=layernorm, without_glb=self.args.v5_in_without_glb) | |
elif self.args.use_objbase_v6: # real_basejtsrel_input_process | |
self.rel_input_feats = 3 + 3 + (21 * 3) + 3 | |
self.real_basejtsrel_input_process = InputProcessObjBaseV6(self.data_rep, self.rel_input_feats, self.latent_dim, layernorm=layernorm) | |
elif self.args.use_objbase_v7: | |
# InputProcessObjBaseV7 | |
self.rel_input_feats = 3 + 3 + (21 * 3) | |
self.real_basejtsrel_input_process = InputProcessObjBaseV7(self.data_rep, self.rel_input_feats, self.latent_dim, layernorm=layernorm) | |
else: | |
self.real_basejtsrel_input_process = InputProcessObjBase(self.data_rep, self.rel_input_feats, self.latent_dim, layernorm=layernorm) | |
self.real_basejtsrel_sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout) | |
### Encoding layer ### | |
real_basejtsrel_seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, # latent dim # nn_heads # ff_size # dropout # # dropout # # dropout | |
nhead=self.num_heads, | |
dim_feedforward=self.ff_size, | |
dropout=self.dropout, | |
activation=self.activation) | |
self.real_basejtsrel_seqTransEncoder = nn.TransformerEncoder(real_basejtsrel_seqTransEncoderLayer, # basejtsrel_seqTrans | |
num_layers=self.num_layers) | |
### timesteps embedding layer ### | |
self.real_basejtsrel_embed_timestep = TimestepEmbedder(self.latent_dim, self.real_basejtsrel_sequence_pos_encoder) | |
self.real_basejtsrel_sequence_pos_denoising_encoder = PositionalEncoding(self.latent_dim, self.dropout) | |
if self.args.use_ours_transformer_enc: # our transformer encoder # | |
self.real_basejtsrel_denoising_seqTransEncoder = model_utils.TransformerEncoder( | |
hidden_size=self.latent_dim, | |
fc_size=self.ff_size, | |
num_heads=self.num_heads, | |
layer_norm=True, | |
num_layers=self.num_layers, | |
dropout_rate=0.2, | |
re_zero=True, | |
memory_efficient=False, | |
) | |
else: | |
real_basejtsrel_denoising_seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, | |
nhead=self.num_heads, | |
dim_feedforward=self.ff_size, | |
dropout=self.dropout, | |
activation=self.activation) | |
self.real_basejtsrel_denoising_seqTransEncoder = nn.TransformerEncoder(real_basejtsrel_denoising_seqTransEncoderLayer, | |
num_layers=self.num_layers) | |
print(f"not_cond_base: {self.args.not_cond_base}, latent_dim: {self.latent_dim}") | |
if self.args.use_jts_pert_realbasejtsrel: | |
print(f"use_jts_pert_realbasejtsrel!!!!!!") | |
self.real_basejtsrel_output_process = OutputProcess(self.data_rep, self.joints_feats_in_dim, self.latent_dim, 21, 3) | |
else: | |
if self.args.use_objbase_out_v3: | |
self.real_basejtsrel_output_process = OutputProcessObjBaseRawV3(self.data_rep, self.latent_dim, not_cond_base=self.args.not_cond_base) | |
elif self.args.use_objbase_out_v4: | |
self.real_basejtsrel_output_process = OutputProcessObjBaseRawV4(self.data_rep, self.latent_dim, not_cond_base=self.args.not_cond_base) | |
elif self.args.use_objbase_out_v5: # use_objbase_v5, use_objbase_out_v5 | |
self.real_basejtsrel_output_process = OutputProcessObjBaseRawV5(self.data_rep, self.latent_dim, not_cond_base=self.args.not_cond_base, out_objbase_v5_bundle_out=self.args.out_objbase_v5_bundle_out, v5_out_not_cond_base=self.args.v5_out_not_cond_base) | |
else: | |
self.real_basejtsrel_output_process = OutputProcessObjBaseRawV2(self.data_rep, self.latent_dim, not_cond_base=self.args.not_cond_base) | |
# OutputProcessCond | |
if self.diff_basejtsrel: ## basejtsrel ## | |
# treate them as textures of signals to model # # base pts -> dec on base pts features --> | |
# latent space denoising and feature decoding --> a little bit concern about the feature decoding process # | |
# TODO: add base_pts and base_normals to the base points -rel-to- rhand joints encoding process # | |
self.rel_input_feats = 21 * (3 + 3 + 3) # relative positions from base pts to rhand joints ## | |
# cond_real_basejtsrel_input_process, cond_real_basejtsrel_pos_encoder, cond_real_basejtsrel_seqTransEncoderLayer, cond_trans_linear_layer # | |
# Conditional strategy 1: use the relative position embeddings for guidance (cannot use the origianl weights for finetuning) # | |
layernorm = True | |
# [finetune_with_cond_rel, finetune_with_cond_jtsobj] | |
if self.args.finetune_with_cond_rel: | |
if self.args.use_objbase_v5: # use_objbase_v5 # | |
if self.args.v5_in_not_base: | |
self.rel_input_feats = (21 * 3) | |
elif self.args.v5_in_not_base_pos: | |
self.rel_input_feats = 3 + (21 * 3) | |
else: # as an additional conditions for the input and denoising # | |
self.rel_input_feats = 3 + 3 + (21 * 3) | |
self.cond_real_basejtsrel_input_process = InputProcessObjBaseV5(self.data_rep, self.rel_input_feats, self.latent_dim, layernorm=layernorm, without_glb=self.args.v5_in_without_glb, only_with_glb=True) | |
else: | |
raise ValueError(f"Must use objbase_v5 currently, others have not been implemented yet.") | |
self.cond_real_basejtsrel_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout) | |
cond_real_basejtsrel_seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, | |
nhead=self.num_heads, | |
dim_feedforward=self.ff_size, | |
dropout=self.dropout, | |
activation=self.activation) | |
self.cond_real_basejtsrel_seqTransEncoderLayer = nn.TransformerEncoder(cond_real_basejtsrel_seqTransEncoderLayer, | |
num_layers=self.num_layers) | |
elif self.args.finetune_with_cond_jtsobj: # cond_obj_trans_layer | |
# InputProcessObjV6(self, data_rep, input_feats, latent_dim, layernorm=True) | |
# cond_obj_input_layer, cond_obj_trans_layer, cond_joints_offset_input_process, cond_sequence_pos_encoder, cond_jtsobj_seqTransEncoder # | |
# TODO: cond_obj_trans_layer : cond_joints_offset_input_process <- joints_offset_input_process; cond_sequence_pos_encoder <- sequence_pos_encoder; cond_seqTransEncoder <- seqTransEncoder | |
self.cond_obj_input_feats = 3 | |
# if self.args.finetune_cond_obj_feats_dim == 3: | |
# self.cond_obj_input_feats = 3 | |
self.cond_obj_input_feats = self.args.finetune_cond_obj_feats_dim # finetune cond with obj feats dim # | |
self.cond_obj_input_layer = InputProcessObjV6(self.data_rep, self.cond_obj_input_feats, self.latent_dim, layernorm=layernorm) | |
# TODO: remember to set this layer to zero! # | |
self.cond_obj_trans_layer = nn.Linear(self.latent_dim, self.latent_dim) | |
# hand_embedding + zero_trans_layer(obj_cond_embedding) | |
self.cond_joints_offset_input_process = InputProcess(self.data_rep, self.joints_feats_in_dim, self.latent_dim) | |
self.cond_sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout) | |
cond_jtsobj_seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, | |
nhead=self.num_heads, | |
dim_feedforward=self.ff_size, | |
dropout=self.dropout, | |
activation=self.activation) | |
self.cond_seqTransEncoder = nn.TransformerEncoder(cond_jtsobj_seqTransEncoderLayer, | |
num_layers=self.num_layers) | |
else: | |
raise ValueError(f"Either ``finetune_with_cond_rel'' or ``finetune_with_cond_jtsobj'' should be activated!") | |
# TODO: remember to initialize it to zero! # | |
self.cond_trans_linear_layer = nn.Linear(self.latent_dim, self.latent_dim) | |
# self.avg_joints_sequence_input_process, self.avg_joint_sequence_output_process # | |
self.avg_joints_sequence_input_process = InputProcess(self.data_rep, self.joints_feats_in_dim, self.latent_dim) | |
# TODO: should set those weights to frozen --> joints offset input process... | |
self.joints_offset_input_process = InputProcess(self.data_rep, self.joints_feats_in_dim, self.latent_dim) | |
if self.args.not_cond_base: | |
self.rel_input_feats = 21 * ( 3) | |
# self.input_process = InputProcessObjBase(self.data_rep, self.rel_input_feats+self.gru_emb_dim, self.latent_dim) | |
self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout) | |
self.emb_trans_dec = emb_trans_dec | |
### Encoding layer ### | |
seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, | |
nhead=self.num_heads, | |
dim_feedforward=self.ff_size, | |
dropout=self.dropout, | |
activation=self.activation) | |
self.seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer, | |
num_layers=self.num_layers) | |
### Encoding layer ### | |
# logvar_seqTransEncoder_e, logvar_seqTransEncoder # logvar_seqTranEncoder | |
seqTransEncoderLayer_logvar = nn.TransformerEncoderLayer(d_model=self.latent_dim, | |
nhead=self.num_heads, | |
dim_feedforward=self.ff_size, | |
dropout=self.dropout, | |
activation=self.activation) | |
self.logvar_seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer_logvar, | |
num_layers=self.num_layers) | |
### timesteps embedding layer ### # TimestepEmbedder -> embedding times # | |
self.embed_timestep = TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder) | |
# basejtsrel_denoising_embed_timestep, basejtsrel_denoising_seqTransEncoder, output_process # # baseptsrel # | |
self.basejtsrel_denoising_embed_timestep = TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder) | |
self.sequence_pos_denoising_encoder = PositionalEncoding(self.latent_dim, self.dropout) | |
if self.args.use_ours_transformer_enc: # our transformer encoder # | |
self.basejtsrel_denoising_seqTransEncoder = model_utils.TransformerEncoder( | |
hidden_size=self.latent_dim, | |
fc_size=self.ff_size, | |
num_heads=self.num_heads, | |
layer_norm=True, | |
num_layers=self.num_layers, | |
dropout_rate=0.2, | |
re_zero=True, | |
memory_efficient=False, | |
) | |
else: | |
seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, | |
nhead=self.num_heads, | |
dim_feedforward=self.ff_size, | |
dropout=self.dropout, | |
activation=self.activation) | |
self.basejtsrel_denoising_seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer, | |
num_layers=self.num_layers) | |
# # seq_len x bsz x dim | |
# if self.args.const_noise: # add to attention network # add j | |
# # 1) max pool latents over the sequence | |
# # 2) transform the pooled latnets via the linear layer | |
# self.basejtsrel_glb_denoising_latents_trans_layer = nn.Sequential( | |
# nn.Linear(self.latent_dim, self.latent_dim * 2), nn.ReLU(), | |
# nn.Linear(self.latent_dim * 2, self.latent_dim) | |
# ) | |
###### joints_feats_in_dim ###### # a linear transformation net with weights and bias set to zero # | |
self.avg_joint_sequence_output_process = OutputProcess(self.data_rep, self.joints_feats_in_dim, self.latent_dim, 21, 3) # output avgjts sequence | |
# OutputProcessCond | |
self.joint_offset_output_process = OutputProcess(self.data_rep, self.joints_feats_in_dim, self.latent_dim, 21, 3) | |
if self.args.use_dec_rel_v2: | |
self.output_process = OutputProcessObjBaseRawV2(self.data_rep, self.latent_dim, not_cond_base=self.args.not_cond_base, finetune_with_cond=self.args.finetune_with_cond) | |
else: | |
# OutputProcessObjBaseRaw ## output process for basejtsrel # | |
self.output_process = OutputProcessObjBaseRaw(self.data_rep, self.latent_dim, not_cond_base=self.args.not_cond_base) | |
##### ==== input process, communications, output process for rel, dists ==== ##### | |
if self.diff_basejtse: | |
### input process obj base ### | |
# construct input_process_e # | |
# self.input_feats_e = 21 * (3 + 3 + 3 + 1 + 1) | |
# self.input_feats_e = 21 * (3 + 3 + 1 + 1) | |
self.input_feats_e = 21 * (3 + 3 + 1 + 1 + 3 + 3 + 1) | |
self.input_process_e = InputProcessObjBase(self.data_rep, self.input_feats_e+self.gru_emb_dim, self.latent_dim) | |
self.sequence_pos_encoder_e = PositionalEncoding(self.latent_dim, self.dropout) | |
self.emb_trans_dec = emb_trans_dec | |
# # single layer transformers # ## predict relative position for each base point? # existing model | |
# if self.arch == 'trans_enc': | |
print("TRANS_ENC init") | |
seqTransEncoderLayer_e = nn.TransformerEncoderLayer(d_model=self.latent_dim, | |
nhead=self.num_heads, | |
dim_feedforward=self.ff_size, | |
dropout=self.dropout, | |
activation=self.activation) | |
self.seqTransEncoder_e = nn.TransformerEncoder(seqTransEncoderLayer_e, | |
num_layers=self.num_layers) | |
print("TRANS_ENC init") | |
# logvar_seqTransEncoder_e, | |
seqTransEncoderLayer_e_logvar = nn.TransformerEncoderLayer(d_model=self.latent_dim, | |
nhead=self.num_heads, | |
dim_feedforward=self.ff_size, | |
dropout=self.dropout, | |
activation=self.activation) | |
self.logvar_seqTransEncoder_e = nn.TransformerEncoder(seqTransEncoderLayer_e_logvar, | |
num_layers=self.num_layers) | |
# | |
# elif self.arch == 'trans_dec': | |
# print("TRANS_DEC init") | |
# seqTransDecoderLayer_e = nn.TransformerDecoderLayer(d_model=self.latent_dim, | |
# nhead=self.num_heads, | |
# dim_feedforward=self.ff_size, | |
# dropout=self.dropout, | |
# activation=activation) | |
# self.seqTransDecoder_e = nn.TransformerDecoder(seqTransDecoderLayer_e, | |
# num_layers=self.num_layers) | |
# elif self.arch == 'gru': ## arch ## | |
# print("GRU init") | |
# self.gru_e = nn.GRU(self.latent_dim, self.latent_dim, num_layers=self.num_layers, batch_first=True) | |
# else: | |
# raise ValueError('Please choose correct architecture [trans_enc, trans_dec, gru]') | |
# tiemstep # # timestep embedding e # Embed timestep e # | |
self.embed_timestep_e = TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder_e) | |
self.sequence_pos_denoising_encoder_e = PositionalEncoding(self.latent_dim, self.dropout) | |
# basejtsrel_denoising_embed_timestep, basejtsrel_denoising_seqTransEncoder, output_process # | |
self.basejtse_denoising_embed_timestep = TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder_e) | |
if self.args.use_ours_transformer_enc: # our transformer encoder # | |
self.basejtse_denoising_seqTransEncoder = model_utils.TransformerEncoder( | |
hidden_size=self.latent_dim, | |
fc_size=self.ff_size, | |
num_heads=self.num_heads, | |
layer_norm=True, | |
num_layers=self.num_layers, | |
dropout_rate=0.2, | |
re_zero=True, | |
memory_efficient=False, | |
) ### basejtse_denoising_seqTransEncoder ### | |
else: | |
basejtse_seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, | |
nhead=self.num_heads, | |
dim_feedforward=self.ff_size, | |
dropout=self.dropout, | |
activation=self.activation) | |
self.basejtse_denoising_seqTransEncoder = nn.TransformerEncoder(basejtse_seqTransEncoderLayer, | |
num_layers=self.num_layers) | |
# basejtse_seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, | |
# nhead=self.num_heads, | |
# dim_feedforward=self.ff_size, | |
# dropout=self.dropout, | |
# activation=self.activation) | |
# self.basejtse_denoising_seqTransEncoder = nn.TransformerEncoder(basejtse_seqTransEncoderLayer, | |
# num_layers=self.num_layers) | |
# seq_len x bsz x dim | |
# if self.args.const_noise: | |
# # 1) max pool latents over the sequence | |
# # 2) transform the pooled latnets via the linear layer | |
# self.basejtse_denoising_seqTransEncoder = nn.Sequential( | |
# nn.Linear(self.latent_dim, self.latent_dim * 2), nn.ReLU(), | |
# nn.Linear(self.latent_dim * 2, self.latent_dim) | |
# ) | |
# self.output_process_e = OutputProcessObjBaseV3(self.data_rep, self.latent_dim) | |
self.output_process_e = OutputProcessObjBaseERaw(self.data_rep, self.latent_dim) | |
# self.rot2xyz = Rotation2xyz(device='cpu', dataset=self.dataset) | |
def set_enc_to_eval(self):
    """Switch the encoder-side sub-modules of every enabled branch to eval mode.

    For each branch flag that is truthy (``diff_jts``, ``diff_basejtse``,
    ``diff_basejtsrel``, checked in that order) ``.eval()`` is called on that
    branch's input process, positional encoder, sequence encoder and
    log-variance encoder.
    """
    # (flag attribute, module attributes) per branch, in switching order.
    encoder_groups = (
        ("diff_jts", (
            "joint_sequence_input_process",
            "joint_sequence_pos_encoder",
            "joint_sequence_seqTransEncoder",
            "joint_sequence_logvar_seqTransEncoder",
        )),
        ("diff_basejtse", (
            "input_process_e",
            "sequence_pos_encoder_e",
            "seqTransEncoder_e",
            "logvar_seqTransEncoder_e",
        )),
        ("diff_basejtsrel", (
            "input_process",
            "sequence_pos_encoder",
            "seqTransEncoder",
            "logvar_seqTransEncoder",
        )),
    )
    for flag_name, module_names in encoder_groups:
        # Flags are looked up lazily so a disabled branch never touches its modules.
        if not getattr(self, flag_name):
            continue
        for module_name in module_names:
            getattr(self, module_name).eval()
def set_bn_to_eval(self):
    """Call ``set_bn_no_training()`` on the point-cloud conv net, if present.

    Does nothing unless ``self.args.use_objbase_v6`` is truthy. When it is,
    ``self.real_basejtsrel_input_process.pnpp_conv_net.set_bn_no_training()``
    is invoked on a best-effort basis: a model that lacks any link of that
    attribute chain is silently left untouched.

    NOTE(review): ``set_bn_no_training`` is defined outside this file view;
    presumably it freezes batch-norm layers -- confirm against its definition.
    """
    if not self.args.use_objbase_v6:  # real_basejtsrel_input_process
        return
    try:
        self.real_basejtsrel_input_process.pnpp_conv_net.set_bn_no_training()
    except AttributeError:
        # Only the "module/attribute not built" case is tolerated; a bare
        # ``except`` would also hide KeyboardInterrupt/SystemExit and real bugs.
        pass
def parameters_wo_clip(self):
    """Return a list of all parameters whose name does not start with ``clip_model.``."""
    kept = []
    for param_name, param in self.named_parameters():
        if param_name.startswith('clip_model.'):
            continue
        kept.append(param)
    return kept
def load_and_freeze_clip(self, clip_version):
    """Load the CLIP model named ``clip_version`` on CPU and freeze it.

    The loaded model is passed through ``clip.model.convert_weights``, put in
    eval mode, and every parameter gets ``requires_grad = False``.

    Returns:
        The frozen CLIP model (the preprocessing transform returned by
        ``clip.load`` is discarded).
    """
    # jit=False must be set for training (per the original author's note).
    frozen_clip, _ = clip.load(clip_version, device='cpu', jit=False)

    # The original author notes this call is redundant because CLIP is
    # already float16 by default; it is kept so behaviour does not change.
    clip.model.convert_weights(frozen_clip)

    # Freeze CLIP weights.
    frozen_clip.eval()
    for weight in frozen_clip.parameters():
        weight.requires_grad = False
    return frozen_clip
def mask_cond(self, cond, force_mask=False):
    """Randomly (or forcibly) zero out rows of the conditioning tensor.

    Args:
        cond: 2-D conditioning tensor; the first dimension is the batch.
        force_mask: when true, an all-zero tensor shaped like ``cond`` is
            returned regardless of training state.

    Returns:
        ``cond`` unchanged outside training or when ``cond_mask_prob`` is not
        positive; otherwise ``cond`` with each row independently zeroed with
        probability ``self.cond_mask_prob``.
    """
    batch_size, _ = cond.shape  # unpacking also rejects non-2-D input
    if force_mask:
        return torch.zeros_like(cond)
    if not (self.training and self.cond_mask_prob > 0.):
        return cond
    drop_prob = torch.ones(batch_size, device=cond.device) * self.cond_mask_prob
    # 1 -> use the null (zero) condition, 0 -> keep the real condition.
    dropped = torch.bernoulli(drop_prob).view(batch_size, 1)
    return cond * (1. - dropped)
def encode_text(self, raw_text):
    """Encode text prompts with the CLIP text encoder.

    Args:
        raw_text: list (batch_size long) of strings with input text prompts.

    Returns:
        ``self.clip_model.encode_text(tokens)`` cast to float.

    For the datasets ``humanml``, ``kit`` and ``motion_ours`` prompts are
    tokenised with a shortened context (20 tokens plus start and end tokens)
    and then zero-padded back to 77 columns; all other datasets use
    ``clip.tokenize`` with its default context length.
    """
    device = next(self.parameters()).device

    # Specific hardcoding for the humanml-style datasets.
    if self.dataset not in ['humanml', 'kit', 'motion_ours']:
        # [bs, context_length]; prompts longer than the context are truncated.
        tokens = clip.tokenize(raw_text, truncate=True).to(device)
        return self.clip_model.encode_text(tokens).float()

    full_context = 77
    short_context = 20 + 2  # start_token + 20 + end_token
    assert short_context < full_context
    # [bs, short_context]; prompts longer than the context are truncated.
    tokens = clip.tokenize(raw_text, context_length=short_context, truncate=True).to(device)
    padding = torch.zeros(
        [tokens.shape[0], full_context - short_context],
        dtype=tokens.dtype,
        device=tokens.device,
    )
    tokens = torch.cat([tokens, padding], dim=1)
    return self.clip_model.encode_text(tokens).float()
def dec_jts_only_fr_latents(self, latents_feats):
    """Decode joint-sequence latents without running the denoiser.

    Returns:
        dict with ``"joint_seq_output"`` -- the result of
        ``joint_sequence_output_process`` with its last axis moved to
        position 1 (``permute(0, 3, 1, 2)``) and made contiguous -- and
        ``"joints_seq_latents"``, the input latents passed through untouched.
    """
    # Original author's note: the output process yields
    # [bs, njoints, nfeats, nframes]; the permute makes it frame-major.
    decoded = self.joint_sequence_output_process(latents_feats)
    return {
        "joint_seq_output": decoded.permute(0, 3, 1, 2).contiguous(),
        "joints_seq_latents": latents_feats,
    }
def dec_basejtsrel_only_fr_latents(self, latent_feats, x):
    """Decode base-joint-relative latents without running the denoiser.

    The first entry along dim 0 of ``latent_feats`` is decoded by
    ``avg_joint_sequence_output_process`` (trailing axis squeezed away); the
    remaining entries are decoded by ``output_process`` together with ``x``.

    Returns:
        dict with ``'avg_jts_outputs'`` and ``'basejtsrel_output'`` (the
        ``'dec_rel'`` entry of the ``output_process`` result).
    """
    avg_token, frame_tokens = latent_feats[0:1, ...], latent_feats[1:, ...]
    avg_joints = self.avg_joint_sequence_output_process(avg_token).squeeze(-1)
    decoded_rel = self.output_process(frame_tokens, x)
    return {
        'avg_jts_outputs': avg_joints,
        'basejtsrel_output': decoded_rel['dec_rel'],
    }
# def dec_latents_to_joints_with_t(self, joints_seq_latents, timesteps): | |
def dec_latents_to_joints_with_t(self, input_latent_feats, x, timesteps):
    """Denoise the (perturbed) latents of every enabled branch and decode them.

    Each of the three branches (``diff_jts``, ``diff_basejtsrel``,
    ``diff_basejtse``) reads its latents from ``input_latent_feats``, optionally
    adds a positional encoding, fuses the timestep embedding, runs its
    denoising transformer encoder and finally decodes through its output
    process. The per-branch result dicts are merged into one returned dict.

    Args:
        input_latent_feats: dict of latents. Keys read here:
            ``"joints_seq_latents"`` / ``"joints_seq_latents_enc"``,
            ``"rel_base_pts_outputs"`` / ``"rel_base_pts_outputs_enc"``,
            ``'base_jts_e_feats'`` / ``"base_jts_e_feats_enc"``
            (the ``*_enc`` keys only when ``args.train_enc`` is set).
        x: forwarded unchanged to ``output_process`` and ``output_process_e``.
        timesteps: forwarded to the timestep embedders.

    Returns:
        dict merging the outputs of all enabled branches (empty when none is).

    NOTE(review): the nesting below was reconstructed from a listing that had
    lost its indentation -- confirm block levels against the repository.
    """
    # joints_seq_latents: seq x bs x d --> perturbed joints_seq_latents in [-1, 1]
    # (shape per the original author's notes -- TODO confirm)
    # NOTE(review): this first ``rt_dict`` is never used; it is rebound to a
    # fresh dict just before the final merge.
    rt_dict = {}
    if self.diff_jts:
        ####### input latent feats #######
        joints_seq_latents = input_latent_feats["joints_seq_latents"]
        if not self.args.without_dec_pos_emb:
            ## positional encoding for denoising ##
            joints_seq_latents = self.joint_sequence_denoising_pos_encoder(joints_seq_latents)
        if self.args.deep_fuse_timeemb:
            # Deep fusion: repeat the timestep embedding along dim 0 to the
            # sequence length and add it to every latent.
            joints_seq_time_emb = self.joint_sequence_embed_timestep(timesteps)
            joints_seq_time_emb = joints_seq_time_emb.repeat(joints_seq_latents.size(0), 1, 1).contiguous()
            joints_seq_latents = joints_seq_latents + joints_seq_time_emb
            if self.args.use_ours_transformer_enc:
                # The custom encoder is fed the (1, 0, 2)-permuted tensor and its
                # output is permuted back (presumably batch-first -- TODO confirm).
                joints_seq_latents = self.joint_sequence_denoising_seqTransEncoder(joints_seq_latents.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none)
                joints_seq_latents = joints_seq_latents.permute(1, 0, 2)
            else:
                joints_seq_latents = self.joint_sequence_denoising_seqTransEncoder(joints_seq_latents)
        else:
            # Token fusion: prepend the timestep embedding as an extra leading
            # entry along dim 0, encode, then drop that entry again with [1:].
            joints_seq_time_emb = self.joint_sequence_embed_timestep(timesteps)
            joints_seq_latents = torch.cat(
                [joints_seq_time_emb, joints_seq_latents], dim=0
            )
            if self.args.use_ours_transformer_enc:  ## mdm ours ##
                joints_seq_latents = self.joint_sequence_denoising_seqTransEncoder(joints_seq_latents.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none)
                joints_seq_latents = joints_seq_latents.permute(1, 0, 2)[1:]
            else:
                joints_seq_latents = self.joint_sequence_denoising_seqTransEncoder(joints_seq_latents)[1:]
        # joints_seq_latents: seq_len x bsz x latent_dim (per original notes)
        if self.args.const_noise:
            # Max-pool over dim 0, transform the pooled latent, then repeat it
            # back to the original sequence length.
            seq_len = joints_seq_latents.size(0)
            joints_seq_latents, _ = torch.max(joints_seq_latents, dim=0, keepdim=True)
            joints_seq_latents = self.glb_denoising_latents_trans_layer(joints_seq_latents)  # seq_len x bsz x latent_dim
            joints_seq_latents = joints_seq_latents.repeat(seq_len, 1, 1).contiguous()
        ### sequence latents ###
        if self.args.train_enc:
            # When training the encoder, the denoiser result computed above is
            # discarded in favour of the encoder latents.
            joints_seq_latents = input_latent_feats["joints_seq_latents_enc"]
        # Original notes: output process yields [bs, njoints, nfeats, nframes];
        # the permute turns it into bsz x ws x nnj x nnfeats.
        joint_seq_output = self.joint_sequence_output_process(joints_seq_latents)
        joint_seq_output = joint_seq_output.permute(0, 3, 1, 2).contiguous()
        diff_jts_dict = {
            "joint_seq_output": joint_seq_output,
            "joints_seq_latents": joints_seq_latents,
        }
    else:
        diff_jts_dict = {}
    if self.diff_basejtsrel:
        rel_base_pts_outputs = input_latent_feats["rel_base_pts_outputs"]
        if rel_base_pts_outputs.size(0) == 1 and self.args.single_frame_noise:
            # A single latent is tiled to window_size + 1 entries along dim 0.
            rel_base_pts_outputs = rel_base_pts_outputs.repeat(self.args.window_size + 1, 1, 1)
        if not self.args.without_dec_pos_emb:
            # Entry 0 (later decoded as the "avg joints" token) is kept as is;
            # only the remaining entries receive the positional encoding.
            avg_jts_inputs = rel_base_pts_outputs[0:1, ...]
            other_rel_base_pts_outputs = rel_base_pts_outputs[1: , ...]
            other_rel_base_pts_outputs = self.sequence_pos_denoising_encoder(other_rel_base_pts_outputs)
            rel_base_pts_outputs = torch.cat(
                [avg_jts_inputs, other_rel_base_pts_outputs], dim=0
            )
        if self.args.deep_fuse_timeemb:
            # Deep fusion: add the repeated timestep embedding to every entry.
            basejtsrel_time_emb = self.basejtsrel_denoising_embed_timestep(timesteps)
            basejtsrel_time_emb = basejtsrel_time_emb.repeat(rel_base_pts_outputs.size(0), 1, 1).contiguous()
            basejtsrel_seq_latents = rel_base_pts_outputs + basejtsrel_time_emb
            if self.args.use_ours_transformer_enc:
                basejtsrel_seq_latents = self.basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none)
                basejtsrel_seq_latents = basejtsrel_seq_latents.permute(1, 0, 2)
            else:
                basejtsrel_seq_latents = self.basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents)
        else:
            # Token fusion: prepend the timestep embedding, encode, drop it via [1:].
            basejtsrel_time_emb = self.basejtsrel_denoising_embed_timestep(timesteps)
            basejtsrel_seq_latents = torch.cat(
                [basejtsrel_time_emb, rel_base_pts_outputs], dim=0
            )
            if self.args.use_ours_transformer_enc:  ## mdm ours ##
                basejtsrel_seq_latents = self.basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none)
                basejtsrel_seq_latents = basejtsrel_seq_latents.permute(1, 0, 2)[1:]
            else:
                basejtsrel_seq_latents = self.basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents)[1:]
        ### sequence latents ###
        if self.args.train_enc:
            # Encoder training: decode the encoder latents instead of the
            # denoiser output (tiled like the input when it is a single entry).
            basejtsrel_seq_latents = input_latent_feats["rel_base_pts_outputs_enc"]
            if basejtsrel_seq_latents.size(0) == 1 and self.args.single_frame_noise:
                basejtsrel_seq_latents = basejtsrel_seq_latents.repeat(self.args.window_size + 1, 1, 1)
            basejtsrel_seq_latents_pred_feats = basejtsrel_seq_latents
        elif self.args.pred_diff_noise:
            # The encoder output is subtracted from the *untiled, un-encoded*
            # input entry of ``input_latent_feats`` before decoding.
            basejtsrel_seq_latents_pred_feats = input_latent_feats["rel_base_pts_outputs"] - basejtsrel_seq_latents
        else:
            basejtsrel_seq_latents_pred_feats = basejtsrel_seq_latents
        ### GET joints seq output ###
        # Entry 0 goes to the avg-joints decoder, the rest to ``output_process``.
        avg_jts_seq_latents = basejtsrel_seq_latents_pred_feats[0:1, ...]
        other_basejtsrel_seq_latents = basejtsrel_seq_latents_pred_feats[1:, ...]
        avg_jts_outputs = self.avg_joint_sequence_output_process(avg_jts_seq_latents)  # bsz x njoints x nfeats x 1 (per original notes)
        avg_jts_outputs = avg_jts_outputs.squeeze(-1)  # trailing axis removed --> avg joints
        basejtsrel_output = self.output_process(other_basejtsrel_seq_latents, x)
        diff_basejtsrel_dict = {
            "basejtsrel_output": basejtsrel_output['dec_rel'],
            "basejtsrel_seq_latents": basejtsrel_seq_latents,
            "avg_jts_outputs": avg_jts_outputs,
        }
    else:
        diff_basejtsrel_dict = {}
    if self.diff_basejtse:
        base_jts_e_feats = input_latent_feats['base_jts_e_feats']  # seq x bs x d --> e feats (per original notes)
        if not self.args.without_dec_pos_emb:
            base_jts_e_feats = self.sequence_pos_denoising_encoder_e(base_jts_e_feats)
        if self.args.deep_fuse_timeemb:
            # Deep fusion: add the repeated timestep embedding to every entry.
            base_jts_e_time_emb = self.embed_timestep_e(timesteps)
            base_jts_e_time_emb = base_jts_e_time_emb.repeat(base_jts_e_feats.size(0), 1, 1).contiguous()
            base_jts_e_feats = base_jts_e_feats + base_jts_e_time_emb
            if self.args.use_ours_transformer_enc:  ## transformer encoder ##
                base_jts_e_feats = self.basejtse_denoising_seqTransEncoder(base_jts_e_feats.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none)
                base_jts_e_feats = base_jts_e_feats.permute(1, 0, 2)
            else:
                base_jts_e_feats = self.basejtse_denoising_seqTransEncoder(base_jts_e_feats)
        else:
            # Token fusion: prepend the timestep embedding, encode, drop it via [1:].
            base_jts_e_time_emb = self.embed_timestep_e(timesteps)
            base_jts_e_feats = torch.cat(
                [base_jts_e_time_emb, base_jts_e_feats], dim=0
            )
            if self.args.use_ours_transformer_enc:  ## transformer encoder ##
                base_jts_e_feats = self.basejtse_denoising_seqTransEncoder(base_jts_e_feats.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none)
                base_jts_e_feats = base_jts_e_feats.permute(1, 0, 2)[1:]
            else:
                base_jts_e_feats = self.basejtse_denoising_seqTransEncoder(base_jts_e_feats)[1:]
        ### sequence latents ###
        if self.args.train_enc:
            # Encoder training: decode the encoder latents instead.
            base_jts_e_feats = input_latent_feats["base_jts_e_feats_enc"]
        #### Decode e_along_normals and e_vt_normals ####
        base_jts_e_output = self.output_process_e(base_jts_e_feats, x)
        # dec_e_along_normals: bsz x (ws - 1) x nnb x nnj (per original notes)
        dec_e_along_normals = base_jts_e_output['dec_e_along_normals']
        dec_e_vt_normals = base_jts_e_output['dec_e_vt_normals']
        # NOTE(review): ``basejtsrel_output`` is only bound inside the
        # ``diff_basejtsrel`` branch above, so enabling ``diff_basejtse``
        # without ``diff_basejtsrel`` fails here with an unbound local.
        dec_d = basejtsrel_output['dec_d']
        rel_vel_dec = basejtsrel_output['rel_vel_dec']
        # Swap the last two axes of every decoded quantity.
        dec_e_along_normals = dec_e_along_normals.contiguous().permute(0, 1, 3, 2).contiguous()
        dec_e_vt_normals = dec_e_vt_normals.contiguous().permute(0, 1, 3, 2).contiguous()
        dec_d = dec_d.contiguous().permute(0, 1, 3, 2).contiguous()
        rel_vel_dec = rel_vel_dec.contiguous().permute(0, 1, 3, 2).contiguous()
        diff_basejtse_dict = {
            'dec_e_along_normals': dec_e_along_normals,
            'dec_e_vt_normals': dec_e_vt_normals,
            'base_jts_e_feats': base_jts_e_feats,
            'dec_d': dec_d,
            'rel_vel_dec': rel_vel_dec,
        }
    else:
        diff_basejtse_dict = {}
    # Merge the per-branch dicts (joints, rel, e) into a single result.
    rt_dict = {}
    rt_dict.update(diff_jts_dict)
    rt_dict.update(diff_basejtsrel_dict)
    rt_dict.update(diff_basejtse_dict)
    return rt_dict
def reparameterization(self, val_mean, val_var):
    """Draw a sample as ``val_mean + noise * val_var`` with standard-normal noise.

    When ``self.args.rnd_noise`` is truthy the raw noise is returned instead
    of the shifted and scaled sample.
    """
    eps = torch.randn_like(val_mean)
    # Always evaluated, even when only the noise is returned, so that
    # incompatible shapes are reported in both modes.
    shifted = val_mean + eps * val_var
    return eps if self.args.rnd_noise else shifted
def decode_realbasejtsrel_from_objbasefeats(self, objbasefeats, input_data):
    """Decode relative positions from object-base features.

    Runs ``real_basejtsrel_output_process`` on the inputs, takes its
    ``'dec_rel'`` entry and swaps axes 2 and 3 of that 5-D tensor.

    Returns:
        The contiguous, axis-swapped ``'dec_rel'`` tensor.
    """
    decoded = self.real_basejtsrel_output_process(objbasefeats, input_data)
    # 'dec_rel' holds the decoded relative positions.
    return decoded['dec_rel'].permute(0, 1, 3, 2, 4).contiguous()
def denoising_realbasejtsrel_objbasefeats(self, pert_obj_base_pts_feats, timesteps):
    """Denoise perturbed object-base point features with the timestep embedding.

    Two fusion modes are selected by ``self.args.deep_fuse_timeemb``:

    * truthy -- the timestep embedding is repeated along dim 0 to the length
      of ``pert_obj_base_pts_feats`` and added to it before encoding;
    * falsy -- the embedding is prepended as an extra leading entry along
      dim 0, and that entry is dropped again after encoding.

    When ``self.args.use_ours_transformer_enc`` is truthy the encoder is fed
    the ``permute(1, 0, 2)`` tensor together with ``set_attn_to_none`` and its
    output is permuted back; otherwise the encoder is called directly.

    Args:
        pert_obj_base_pts_feats: perturbed latent features; dim 0 is the axis
            the timestep embedding is repeated / concatenated along.
        timesteps: passed to ``real_basejtsrel_embed_timestep``.

    Returns:
        The denoised latents, with the same dim-0 length as the input.
    """
    args = self.args
    time_emb = self.real_basejtsrel_embed_timestep(timesteps)

    if args.deep_fuse_timeemb:
        time_emb = time_emb.repeat(pert_obj_base_pts_feats.size(0), 1, 1).contiguous()
        seq_latents = pert_obj_base_pts_feats + time_emb
    else:
        seq_latents = torch.cat([time_emb, pert_obj_base_pts_feats], dim=0)

    encoder = self.real_basejtsrel_denoising_seqTransEncoder
    if args.use_ours_transformer_enc:
        seq_latents = encoder(seq_latents.permute(1, 0, 2), set_attn_to_none=args.set_attn_to_none)
        seq_latents = seq_latents.permute(1, 0, 2)
    else:
        # The original passed the undefined name ``basejtsrel_seq_latents``
        # on the token-fusion path, which raised NameError.
        seq_latents = encoder(seq_latents)

    if not args.deep_fuse_timeemb:
        # Remove the prepended timestep entry.
        seq_latents = seq_latents[1:]
    return seq_latents
def get_cond_parameters(self):
    """Collect the parameters of the conditioning sub-modules.

    Which modules contribute depends on the fine-tuning mode in ``self.args``:

    * ``finetune_with_cond_rel``: ``cond_real_basejtsrel_input_process``,
      ``cond_real_basejtsrel_pos_encoder``,
      ``cond_real_basejtsrel_seqTransEncoderLayer``, ``cond_trans_linear_layer``
    * ``finetune_with_cond_jtsobj``: ``cond_joints_offset_input_process``,
      ``cond_sequence_pos_encoder``, ``cond_seqTransEncoder``,
      ``cond_obj_trans_layer``, ``cond_obj_input_layer``

    Returns:
        A flat list of parameters, in the module order listed above.

    Raises:
        ValueError: if ``args.diff_basejtsrel`` is off, or neither
            fine-tuning mode is enabled.
    """
    if not self.args.diff_basejtsrel:
        raise ValueError(f"Must use diff_basejtsrel currently, others have not been implemented yet.")

    if self.args.finetune_with_cond_rel:
        cond_modules = (
            self.cond_real_basejtsrel_input_process,
            self.cond_real_basejtsrel_pos_encoder,
            self.cond_real_basejtsrel_seqTransEncoderLayer,
            self.cond_trans_linear_layer,
        )
    elif self.args.finetune_with_cond_jtsobj:
        cond_modules = (
            self.cond_joints_offset_input_process,
            self.cond_sequence_pos_encoder,
            self.cond_seqTransEncoder,
            self.cond_obj_trans_layer,
            self.cond_obj_input_layer,
        )
    else:
        raise ValueError(f"Either ``finetune_with_cond_rel'' or ``finetune_with_cond_jtsobj'' should be activated !")

    params = []
    for module in cond_modules:
        params.extend(module.parameters())
    return params
def set_trans_linear_layer_to_zero(self):
    """Zero-initialise the conditioning projection layer(s).

    ``cond_trans_linear_layer`` always has its weight and bias set to zero;
    ``cond_obj_trans_layer`` is zeroed as well when
    ``self.args.finetune_with_cond_jtsobj`` is truthy.
    """
    zero_fill = torch.nn.init.zeros_

    fuse_layer = self.cond_trans_linear_layer
    zero_fill(fuse_layer.weight)
    zero_fill(fuse_layer.bias)

    # finetune_with_cond_jtsobj additionally uses cond_obj_trans_layer.
    if not self.args.finetune_with_cond_jtsobj:
        return
    obj_layer = self.cond_obj_trans_layer
    zero_fill(obj_layer.weight)
    zero_fill(obj_layer.bias)
def forward(self, x, timesteps):
    """Run every enabled denoising branch and merge their outputs into one dict.

    Shape notes in the comments below are carried over from the original
    inline comments and should be confirmed against the data loader.

    Args:
        x: Mapping from string keys to tensors (it is indexed by key, not used
            as a tensor). ``'rhand_joints'``, ``'base_pts'`` and
            ``'base_normals'`` are always read. Further keys are read per branch:

            * ``self.diff_basejtse``:
              ``'pert_e_disp_rel_to_base_along_normals'``,
              ``'pert_e_disp_rel_to_base_vt_normals'``, ``'obj_pts_disp'``,
              ``'vel_obj_pts_to_hand_pts'``, ``'disp_dist'``.
            * ``self.args.diff_realbasejtsrel_to_joints``:
              ``'pert_rel_base_pts_to_joints_for_jts_pred'``, ``'normed_base_pts'``.
            * ``self.diff_realbasejtsrel``:
              ``'pert_rel_base_pts_to_rhand_joints'``, ``'normed_base_pts'``.
            * ``self.diff_basejtsrel``:
              ``'pert_rel_base_pts_to_rhand_joints'``, ``'normed_base_pts'``,
              ``'pert_joints_offset_sequence'``.

            ``x`` itself is also handed to ``output_process_e`` and (in the
            non joint-perturbation case) to ``real_basejtsrel_output_process``.
        timesteps: Diffusion step indices, forwarded to the branch-specific
            ``*embed_timestep`` modules.

    Returns:
        dict with the outputs of the enabled branches:

        * ``self.diff_basejtsrel``: ``'joints_offset_output'``.
        * ``self.diff_basejtse``: ``'dec_e_along_normals'``,
          ``'dec_e_vt_normals'``, ``'base_jts_e_feats'``, ``'dec_d'``,
          ``'rel_vel_dec'``.
        * ``self.args.diff_realbasejtsrel_to_joints``:
          ``'joints_offset_output_from_rel'`` when ``args.diff_basejtsrel`` is
          set, otherwise ``'joints_offset_output'``.
        * ``self.diff_realbasejtsrel``: ``'real_dec_basejtsrel'``,
          ``'obj_base_pts_feats'``.

    Raises:
        ValueError: in the ``diff_basejtsrel`` branch when no supported
            conditioning flag / ``cond_obj_input_feats`` value is configured.
    """
    args = self.args

    def fuse_time_and_denoise(feats, time_emb, denoiser):
        """Inject ``time_emb`` into ``feats`` and run ``denoiser`` on the result.

        With ``deep_fuse_timeemb`` the embedding is added to every sequence
        element; otherwise it is prepended as an extra leading token, which
        is dropped again from the denoiser output.
        """
        prepend_token = not args.deep_fuse_timeemb
        if prepend_token:
            latents = torch.cat([time_emb, feats], dim=0)
        else:
            latents = feats + time_emb.repeat(feats.size(0), 1, 1).contiguous()
        if args.use_ours_transformer_enc:
            # This encoder variant is fed the permuted (1, 0, 2) layout.
            latents = denoiser(latents.permute(1, 0, 2), set_attn_to_none=args.set_attn_to_none)
            latents = latents.permute(1, 0, 2)
        else:
            latents = denoiser(latents)
        return latents[1:] if prepend_token else latents

    bsz, _, nnj = x['rhand_joints'].shape[:3]
    pert_rhand_joints = x['rhand_joints']  # bsz x ws x nnj x 3
    base_pts = x['base_pts']  # bsz x nnb x 3
    base_normals = x['base_normals']  # bsz x nnb x 3
    rt_dict = {}

    # ---- Branch 1: energy-like quantities relative to the base points ----
    if self.diff_basejtse:
        e_disp_rel_to_base_along_normals = x['pert_e_disp_rel_to_base_along_normals']
        e_disp_rel_to_base_vt_normals = x['pert_e_disp_rel_to_base_vt_normals']
        obj_pts_disp = x['obj_pts_disp']
        vel_obj_pts_to_hand_pts = x['vel_obj_pts_to_hand_pts']
        disp_dist = x['disp_dist']
        nnb = base_pts.size(1)
        disp_ws = e_disp_rel_to_base_along_normals.size(1)
        base_pts_disp_exp = base_pts.unsqueeze(1).unsqueeze(1).repeat(1, disp_ws, nnj, 1, 1).contiguous()
        base_normals_disp_exp = base_normals.unsqueeze(1).unsqueeze(1).repeat(1, disp_ws, nnj, 1, 1).contiguous()
        obj_pts_disp_exp = obj_pts_disp.unsqueeze(1).unsqueeze(1).repeat(1, disp_ws, nnj, 1, 1).contiguous()
        base_pts_normals_e_in_feats = torch.cat(
            [
                base_pts_disp_exp,
                base_normals_disp_exp,
                obj_pts_disp_exp,
                vel_obj_pts_to_hand_pts,
                disp_dist,
                e_disp_rel_to_base_along_normals.unsqueeze(-1),
                e_disp_rel_to_base_vt_normals.unsqueeze(-1),
            ],
            dim=-1,
        )
        # Move the base-point axis ahead of the joint axis, then flatten the
        # per-joint features: bsz x (ws - 1) x nnb x (nnj * feats_dim).
        base_pts_normals_e_in_feats = base_pts_normals_e_in_feats.permute(0, 1, 3, 2, 4).contiguous()
        base_pts_normals_e_in_feats = base_pts_normals_e_in_feats.view(bsz, disp_ws, nnb, -1).contiguous()

        base_jts_e_feats = self.input_process_e(base_pts_normals_e_in_feats)
        base_jts_e_feats = self.sequence_pos_encoder_e(base_jts_e_feats)
        base_jts_e_feats = self.seqTransEncoder_e(base_jts_e_feats)
        if not args.without_dec_pos_emb:
            base_jts_e_feats = self.sequence_pos_denoising_encoder_e(base_jts_e_feats)
        base_jts_e_feats = fuse_time_and_denoise(
            base_jts_e_feats,
            self.embed_timestep_e(timesteps),
            self.basejtse_denoising_seqTransEncoder,
        )

        base_jts_e_output = self.output_process_e(base_jts_e_feats, x)
        # Decoded as bsz x (ws - 1) x nnb x nnj; swap to bsz x (ws - 1) x nnj x nnb.
        dec_e_along_normals = base_jts_e_output['dec_e_along_normals'].contiguous().permute(0, 1, 3, 2).contiguous()
        dec_e_vt_normals = base_jts_e_output['dec_e_vt_normals'].contiguous().permute(0, 1, 3, 2).contiguous()
        diff_basejtse_dict = {
            'dec_e_along_normals': dec_e_along_normals,
            'dec_e_vt_normals': dec_e_vt_normals,
            'base_jts_e_feats': base_jts_e_feats,
            'dec_d': base_jts_e_output['dec_d'],
            'rel_vel_dec': base_jts_e_output['rel_vel_dec'],
        }
    else:
        diff_basejtse_dict = {}

    # ---- Branch 2: encoding of the raw joint sequence ----
    if self.diff_jts:
        pert_rhand_joints_trans = pert_rhand_joints.permute(0, 2, 3, 1).contiguous()  # bsz x nnj x 3 x ws
        xseq = self.joint_sequence_input_process(pert_rhand_joints_trans)
        xseq = self.joint_sequence_pos_encoder(xseq)
        joint_seq_output_mean = self.joint_sequence_seqTransEncoder(xseq)
        joint_seq_output_logvar = self.joint_sequence_logvar_seqTransEncoder(xseq)
        joint_seq_output_var = torch.exp(joint_seq_output_logvar)
        joint_seq_output = self.reparameterization(joint_seq_output_mean, joint_seq_output_var)
        rt_dict['joint_seq_output'] = joint_seq_output
        rt_dict['joint_seq_output_mean'] = joint_seq_output_mean
        rt_dict['joint_seq_output_logvar'] = joint_seq_output_logvar

    # ---- Branch 3: joints predicted from base-point-relative positions ----
    if args.diff_realbasejtsrel_to_joints:
        bsz, nf, nnj, nnb = x['pert_rel_base_pts_to_joints_for_jts_pred'].size()[:4]
        normed_base_pts = x['normed_base_pts']
        base_normals = x['base_normals']
        pert_rel_base_pts_to_rhand_joints = x['pert_rel_base_pts_to_joints_for_jts_pred']  # bsz x nf x nnj x nnb x 3

        if args.use_abs_jts_pos:
            # Turn relative offsets into absolute joint positions.
            pert_rel_base_pts_to_rhand_joints = pert_rel_base_pts_to_rhand_joints + normed_base_pts.unsqueeze(1).unsqueeze(1)

        if args.use_abs_jts_for_encoding:
            if not args.use_abs_jts_pos:
                pert_rel_base_pts_to_rhand_joints = pert_rel_base_pts_to_rhand_joints + normed_base_pts.unsqueeze(1).unsqueeze(1)
            # Only the entry for the first base point is encoded.
            abs_jts = pert_rel_base_pts_to_rhand_joints[..., 0, :]
            abs_jts = abs_jts.permute(0, 2, 3, 1).contiguous()
            obj_base_encoded_feats = self.real_basejtsrel_to_joints_input_process(abs_jts)
        elif args.use_abs_jts_for_encoding_obj_base:
            if not args.use_abs_jts_pos:
                pert_rel_base_pts_to_rhand_joints = pert_rel_base_pts_to_rhand_joints + normed_base_pts.unsqueeze(1).unsqueeze(1)
            pert_rel_base_pts_to_rhand_joints = pert_rel_base_pts_to_rhand_joints[:, :, :, 0:1, :]
            obj_base_in_feats = pert_rel_base_pts_to_rhand_joints
            # Flatten all joints into one feature vector for the single kept base point.
            obj_base_in_feats = obj_base_in_feats.transpose(-2, -3).contiguous().view(bsz, nf, 1, -1).contiguous()
            obj_base_encoded_feats = self.real_basejtsrel_to_joints_input_process(obj_base_in_feats)
        else:
            if args.use_objbase_v2:
                pert_rel_base_pts_to_rhand_joints_exp = pert_rel_base_pts_to_rhand_joints.transpose(-2, -3).contiguous().view(bsz, nf, nnb, -1).contiguous()
                obj_base_in_feats = torch.cat(
                    [
                        normed_base_pts.unsqueeze(1).repeat(1, nf, 1, 1),
                        base_normals.unsqueeze(1).repeat(1, nf, 1, 1),
                        pert_rel_base_pts_to_rhand_joints_exp,
                    ],
                    dim=-1,
                )
            else:
                obj_base_in_feats = torch.cat(  # bsz x nf x nnj x nnb x (3 + 3 + 3)
                    [
                        normed_base_pts.unsqueeze(1).unsqueeze(1).repeat(1, nf, nnj, 1, 1),
                        base_normals.unsqueeze(1).unsqueeze(1).repeat(1, nf, nnj, 1, 1),
                        pert_rel_base_pts_to_rhand_joints,
                    ],
                    dim=-1,
                )
                obj_base_in_feats = obj_base_in_feats.transpose(-2, -3).contiguous().view(bsz, nf, nnb, -1).contiguous()
            obj_base_encoded_feats = self.real_basejtsrel_to_joints_input_process(obj_base_in_feats)

        obj_base_pts_feats_pos_embedding = self.real_basejtsrel_to_joints_sequence_pos_encoder(obj_base_encoded_feats)
        obj_base_pts_feats = self.real_basejtsrel_to_joints_seqTransEncoder(obj_base_pts_feats_pos_embedding)
        if args.use_sigmoid:
            # Squash the latents into (-1, 1).
            obj_base_pts_feats = (torch.sigmoid(obj_base_pts_feats) - 0.5) * 2.

        # The shared helper always denoises the latents built here; the old
        # inline code passed an unrelated, not-yet-defined variable in the
        # plain-encoder / prepended-token case.
        real_basejtsrel_seq_latents = fuse_time_and_denoise(
            obj_base_pts_feats,
            self.real_basejtsrel_to_joints_embed_timestep(timesteps),
            self.real_basejtsrel_to_joints_denoising_seqTransEncoder,
        )
        joints_offset_output = self.real_basejtsrel_to_joints_output_process(real_basejtsrel_seq_latents)
        joints_offset_output = joints_offset_output.permute(0, 3, 1, 2)
        if args.diff_basejtsrel:
            # Avoid clashing with the key produced by the diff_basejtsrel branch.
            diff_basejtsrel_to_joints_dict = {'joints_offset_output_from_rel': joints_offset_output}
        else:
            diff_basejtsrel_to_joints_dict = {'joints_offset_output': joints_offset_output}
    else:
        diff_basejtsrel_to_joints_dict = {}

    # ---- Branch 4: denoising of base-point-relative joint positions ----
    if self.diff_realbasejtsrel:
        bsz, nf, nnj, nnb = x['pert_rel_base_pts_to_rhand_joints'].size()[:4]
        normed_base_pts = x['normed_base_pts']
        base_normals = x['base_normals']
        pert_rel_base_pts_to_rhand_joints = x['pert_rel_base_pts_to_rhand_joints']  # bsz x nf x nnj x nnb x 3

        if args.use_objbase_v2:
            pert_rel_base_pts_to_rhand_joints_exp = pert_rel_base_pts_to_rhand_joints.transpose(-2, -3).contiguous().view(bsz, nf, nnb, -1).contiguous()
            obj_base_in_feats = torch.cat(
                [
                    normed_base_pts.unsqueeze(1).repeat(1, nf, 1, 1),
                    base_normals.unsqueeze(1).repeat(1, nf, 1, 1),
                    pert_rel_base_pts_to_rhand_joints_exp,
                ],
                dim=-1,
            )
        elif args.use_objbase_v4:
            exp_normed_base_pts = normed_base_pts.unsqueeze(1).unsqueeze(2).repeat(1, nf, nnj, 1, 1).contiguous()
            exp_base_normals = base_normals.unsqueeze(1).unsqueeze(2).repeat(1, nf, nnj, 1, 1).contiguous()
            obj_base_in_feats = torch.cat(  # bsz x nf x nnj x nnb x (3 + 3 + 3)
                [pert_rel_base_pts_to_rhand_joints, exp_normed_base_pts, exp_base_normals], dim=-1
            )
            obj_base_in_feats = obj_base_in_feats.view(bsz, nf, nnj, -1).contiguous()
        elif args.use_objbase_v5:
            pert_rel_base_pts_to_rhand_joints_exp = pert_rel_base_pts_to_rhand_joints.transpose(-2, -3).contiguous().view(bsz, nf, nnb, -1).contiguous()
            if args.v5_in_not_base:
                obj_base_in_feats = torch.cat([pert_rel_base_pts_to_rhand_joints_exp], dim=-1)
            elif args.v5_in_not_base_pos:
                obj_base_in_feats = torch.cat(
                    [base_normals.unsqueeze(1).repeat(1, nf, 1, 1), pert_rel_base_pts_to_rhand_joints_exp], dim=-1
                )
            else:
                obj_base_in_feats = torch.cat(
                    [
                        normed_base_pts.unsqueeze(1).repeat(1, nf, 1, 1),
                        base_normals.unsqueeze(1).repeat(1, nf, 1, 1),
                        pert_rel_base_pts_to_rhand_joints_exp,
                    ],
                    dim=-1,
                )
        elif args.use_objbase_v6 or args.use_objbase_v7:
            pert_rel_base_pts_to_rhand_joints_exp = pert_rel_base_pts_to_rhand_joints.transpose(-2, -3).contiguous().view(bsz, nf, nnb, -1).contiguous()
            obj_base_in_feats = torch.cat(
                [
                    normed_base_pts.unsqueeze(1).repeat(1, nf, 1, 1),
                    base_normals.unsqueeze(1).repeat(1, nf, 1, 1),
                    pert_rel_base_pts_to_rhand_joints_exp,
                ],
                dim=-1,
            )
        else:
            obj_base_in_feats = torch.cat(  # bsz x nf x nnj x nnb x (3 + 3 + 3)
                [
                    normed_base_pts.unsqueeze(1).unsqueeze(1).repeat(1, nf, nnj, 1, 1),
                    base_normals.unsqueeze(1).unsqueeze(1).repeat(1, nf, nnj, 1, 1),
                    pert_rel_base_pts_to_rhand_joints,
                ],
                dim=-1,
            )
            obj_base_in_feats = obj_base_in_feats.transpose(-2, -3).contiguous().view(bsz, nf, nnb, -1).contiguous()

        if args.use_objbase_v6:
            # The v6 input process additionally receives the per-frame base points.
            normed_base_pts_exp = normed_base_pts.unsqueeze(1).repeat(1, nf, 1, 1)
            obj_base_encoded_feats = self.real_basejtsrel_input_process(obj_base_in_feats, normed_base_pts_exp)
        else:
            obj_base_encoded_feats = self.real_basejtsrel_input_process(obj_base_in_feats)

        obj_base_pts_feats_pos_embedding = self.real_basejtsrel_sequence_pos_encoder(obj_base_encoded_feats)
        obj_base_pts_feats = self.real_basejtsrel_seqTransEncoder(obj_base_pts_feats_pos_embedding)
        if args.use_sigmoid:
            obj_base_pts_feats = (torch.sigmoid(obj_base_pts_feats) - 0.5) * 2.

        if args.train_enc:
            # Encoder-only training: decode straight from the encoder latents.
            real_dec_basejtsrel = self.real_basejtsrel_output_process(obj_base_pts_feats, x)
            real_dec_basejtsrel = real_dec_basejtsrel['dec_rel']
            if not args.use_objbase_out_v3:
                real_dec_basejtsrel = real_dec_basejtsrel.permute(0, 1, 3, 2, 4).contiguous()  # bsz x nf x nnj x nnb x 3
        else:
            if args.deep_fuse_timeemb and args.use_objbase_v5:
                # One timestep entry per (batch element, base point) pair.
                cur_timesteps = timesteps.unsqueeze(1).repeat(1, nnb).view(-1)
            else:
                cur_timesteps = timesteps
            # Same helper as above: it denoises the latents built in this
            # branch instead of the undefined variable the old code used.
            real_basejtsrel_seq_latents = fuse_time_and_denoise(
                obj_base_pts_feats,
                self.real_basejtsrel_embed_timestep(cur_timesteps),
                self.real_basejtsrel_denoising_seqTransEncoder,
            )

            if args.use_jts_pert_realbasejtsrel:
                joints_offset_output = self.real_basejtsrel_output_process(real_basejtsrel_seq_latents)
                joints_offset_output = joints_offset_output.permute(0, 3, 1, 2)  # bsz x nf x nnj x 3
                # Broadcast the per-joint output over the base points.
                real_dec_basejtsrel = joints_offset_output.unsqueeze(-2).repeat(1, 1, 1, nnb, 1)
            else:
                real_dec_basejtsrel = self.real_basejtsrel_output_process(real_basejtsrel_seq_latents, x)
                real_dec_basejtsrel = real_dec_basejtsrel['dec_rel']
                if not (args.use_objbase_out_v3 or args.use_objbase_out_v4 or args.use_objbase_out_v5):
                    real_dec_basejtsrel = real_dec_basejtsrel.permute(0, 1, 3, 2, 4).contiguous()  # bsz x nf x nnj x nnb x 3

        diff_realbasejtsrel_out_dict = {
            'real_dec_basejtsrel': real_dec_basejtsrel,
            'obj_base_pts_feats': obj_base_pts_feats,
        }
    else:
        diff_realbasejtsrel_out_dict = {}

    # ---- Branch 5: joint offsets denoised with a conditioning signal ----
    if self.diff_basejtsrel:
        bsz, nf, nnj, nnb = x['pert_rel_base_pts_to_rhand_joints'].size()[:4]
        normed_base_pts = x['normed_base_pts']
        base_normals = x['base_normals']
        pert_rel_base_pts_to_rhand_joints = x['pert_rel_base_pts_to_rhand_joints']  # bsz x nf x nnj x nnb x 3

        if args.finetune_with_cond_rel:
            if not args.use_objbase_v5:
                raise ValueError("Must use objbase_v5 currently, others have not been implemented yet.")
            pert_rel_base_pts_to_rhand_joints_exp = pert_rel_base_pts_to_rhand_joints.transpose(-2, -3).contiguous().view(bsz, nf, nnb, -1).contiguous()
            if args.v5_in_not_base:
                obj_base_in_feats = torch.cat([pert_rel_base_pts_to_rhand_joints_exp], dim=-1)
            elif args.v5_in_not_base_pos:
                obj_base_in_feats = torch.cat(
                    [base_normals.unsqueeze(1).repeat(1, nf, 1, 1), pert_rel_base_pts_to_rhand_joints_exp], dim=-1
                )
            else:
                obj_base_in_feats = torch.cat(
                    [
                        normed_base_pts.unsqueeze(1).repeat(1, nf, 1, 1),
                        base_normals.unsqueeze(1).repeat(1, nf, 1, 1),
                        pert_rel_base_pts_to_rhand_joints_exp,
                    ],
                    dim=-1,
                )
            obj_base_encoded_feats = self.cond_real_basejtsrel_input_process(obj_base_in_feats)
            obj_base_pts_feats_pos_embedding = self.cond_real_basejtsrel_pos_encoder(obj_base_encoded_feats)
            obj_base_pts_feats = self.cond_real_basejtsrel_seqTransEncoderLayer(obj_base_pts_feats_pos_embedding)
        elif args.finetune_with_cond_jtsobj:
            cond_joints_offset_sequence = x['pert_joints_offset_sequence']  # bsz x nf x nnj x 3
            cond_joints_offset_sequence = cond_joints_offset_sequence.permute(0, 2, 3, 1).contiguous()
            cond_joints_offset_feats = self.cond_joints_offset_input_process(cond_joints_offset_sequence)

            # The object-side input is selected by its configured feature width.
            if self.cond_obj_input_feats == 3:
                normed_base_pts_exp = normed_base_pts.unsqueeze(1).repeat(1, nf, 1, 1).contiguous()
            elif self.cond_obj_input_feats == 6:
                normed_base_pts_exp = normed_base_pts.unsqueeze(1).repeat(1, nf, 1, 1).contiguous()
                normed_base_normals_exp = base_normals.unsqueeze(1).repeat(1, nf, 1, 1).contiguous()
                normed_base_pts_exp = torch.cat([normed_base_pts_exp, normed_base_normals_exp], dim=-1)
            elif self.cond_obj_input_feats == (6 + 21 * 3):
                normed_base_pts_exp = normed_base_pts.unsqueeze(1).repeat(1, nf, 1, 1).contiguous()
                normed_base_normals_exp = base_normals.unsqueeze(1).repeat(1, nf, 1, 1).contiguous()
                pert_rel_base_pts_to_rhand_joints_exp = pert_rel_base_pts_to_rhand_joints.transpose(-2, -3).contiguous().view(bsz, nf, nnb, -1).contiguous()
                normed_base_pts_exp = torch.cat(
                    [normed_base_pts_exp, normed_base_normals_exp, pert_rel_base_pts_to_rhand_joints_exp], dim=-1
                )
            else:
                raise ValueError(f"Unrecognized cond_obj_input_feats: {self.cond_obj_input_feats}")

            cond_obj_feats = self.cond_obj_input_layer(normed_base_pts_exp)
            cond_obj_feats = self.cond_obj_trans_layer(cond_obj_feats)
            cond_jtsobj_feats = cond_joints_offset_feats + cond_obj_feats
            cond_jtsobj_feats = self.cond_sequence_pos_encoder(cond_jtsobj_feats)
            obj_base_pts_feats = self.cond_seqTransEncoder(cond_jtsobj_feats)
        else:
            raise ValueError("Either ``finetune_with_cond_rel'' or ``finetune_with_cond_jtsobj'' should be activated! ")

        obj_base_pts_feats = self.cond_trans_linear_layer(obj_base_pts_feats)

        # Encode the perturbed joint offsets.
        joints_offset_sequence = x['pert_joints_offset_sequence']  # bsz x nf x nnj x 3
        joints_offset_sequence = joints_offset_sequence.permute(0, 2, 3, 1).contiguous()
        joints_offset_feats = self.joints_offset_input_process(joints_offset_sequence)
        rel_base_pts_feats = self.sequence_pos_encoder(joints_offset_feats)
        rel_base_pts_outputs_mean = self.seqTransEncoder(rel_base_pts_feats)

        # Fuse the conditioning features with the joint latents by addition.
        rel_base_pts_outputs_mean = rel_base_pts_outputs_mean + obj_base_pts_feats

        if not args.without_dec_pos_emb:
            rel_base_pts_outputs_mean = self.sequence_pos_denoising_encoder(rel_base_pts_outputs_mean)

        basejtsrel_seq_latents = fuse_time_and_denoise(
            rel_base_pts_outputs_mean,
            self.basejtsrel_denoising_embed_timestep(timesteps),
            self.basejtsrel_denoising_seqTransEncoder,
        )
        joints_offset_output = self.joint_offset_output_process(basejtsrel_seq_latents)
        joints_offset_output = joints_offset_output.permute(0, 3, 1, 2)
        diff_basejtsrel_dict = {'joints_offset_output': joints_offset_output}
    else:
        diff_basejtsrel_dict = {}

    # NOTE(review): this reset drops the 'joint_seq_output*' entries written by
    # the diff_jts branch. Kept as in the original so the set of returned keys
    # does not change -- confirm whether those entries are meant to be returned.
    rt_dict = {}
    rt_dict.update(diff_basejtsrel_dict)
    rt_dict.update(diff_basejtse_dict)
    rt_dict.update(diff_basejtsrel_to_joints_dict)
    rt_dict.update(diff_realbasejtsrel_out_dict)
    return rt_dict
def _apply(self, fn, *args, **kwargs):
    """Apply ``fn`` to the module's tensors via ``nn.Module._apply``.

    The result of the parent call is returned so that the public wrappers
    built on top of ``_apply`` (``to``, ``cuda``, ``cpu``, ``float`` ...) keep
    handing back the module instead of ``None``. Any additional positional
    or keyword arguments are passed through untouched, which keeps the
    override compatible with PyTorch versions whose ``_apply`` accepts more
    than ``fn``.

    Args:
        fn: callable applied to parameters and buffers by the parent class.

    Returns:
        Whatever ``nn.Module._apply`` returns (the module itself).
    """
    return super()._apply(fn, *args, **kwargs)
def train(self, *args, **kwargs):
    """Switch training/evaluation mode through ``nn.Module.train``.

    The parent's return value is propagated so that ``model.train()`` and
    ``model.eval()`` (which calls ``train(False)``) return the module and can
    be chained or re-assigned, instead of yielding ``None``.

    Returns:
        Whatever ``nn.Module.train`` returns (the module itself).
    """
    return super().train(*args, **kwargs)
### LayerNorm layer ### | |
class LayerNorm(nn.Module):
    """Parameter-free normalisation along a single dimension.

    The input is centred by its mean over ``dim`` and divided by
    ``sqrt(var + eps)``, where ``var`` is ``torch.Tensor.var`` with its default
    settings. There is no learnable scale or shift.
    """

    def __init__(self, dim, eps=1e-5):
        """Store the dimension to normalise over and the stabilising ``eps``."""
        super().__init__()
        self.dim, self.eps = dim, eps

    def forward(self, x):
        """Return ``x`` normalised along ``self.dim`` (same shape as ``x``)."""
        centered = x - x.mean(dim=self.dim, keepdim=True)
        scale = torch.sqrt(x.var(dim=self.dim, keepdim=True) + self.eps)
        return centered / scale
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding followed by dropout.

    A table ``pe`` of shape ``(max_len, 1, d_model)`` is precomputed and
    registered as a buffer. Even feature columns hold sines and odd columns
    hold cosines of ``position * div_term``.

    Args:
        d_model: feature width of the table. Odd values are supported; the
            last (unpaired) column then holds a sine term.
        dropout: dropout probability applied to the output of ``forward``.
        max_len: number of positions stored in the table.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        angles = position * div_term  # max_len x ceil(d_model / 2)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one cosine column fewer than sine columns;
        # for even d_model the slice keeps every column.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0).transpose(0, 1)  # max_len x 1 x d_model
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the first ``x.shape[0]`` table rows to ``x`` and apply dropout.

        The table row is selected by the index along ``x``'s first dimension
        and broadcast over the second one.
        """
        x = x + self.pe[:x.shape[0], :]
        return self.dropout(x)
class TimestepEmbedder(nn.Module):
    """Embed diffusion timesteps with a positional table and a two-layer MLP.

    Args:
        latent_dim: input and output width of the MLP.
        sequence_pos_encoder: object exposing a ``pe`` table that is indexed
            with the timesteps.
    """

    def __init__(self, latent_dim, sequence_pos_encoder):
        super().__init__()
        self.latent_dim = latent_dim
        self.sequence_pos_encoder = sequence_pos_encoder

        hidden_dim = self.latent_dim
        mlp_layers = [
            nn.Linear(self.latent_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        ]
        self.time_embed = nn.Sequential(*mlp_layers)

    def forward(self, timesteps):
        """Look up ``pe[timesteps]``, run the MLP and swap the first two axes."""
        table_rows = self.sequence_pos_encoder.pe[timesteps]
        embedded = self.time_embed(table_rows)
        return embedded.permute(1, 0, 2)
## hand sequence; object shape | |
# InputProcessObjBase(self, data_rep, input_feats, latent_dim) | |
class InputProcessObjBaseV5(nn.Module):
    """Embed per-base-point features, optionally fused with a pooled global feature.

    Args:
        data_rep: stored on the instance; not read by this class.
        input_feats: size of the last input dimension (the original note
            suggests 21 * 3 + 3 + 3, i.e. joints + position + normal -- TODO confirm).
        latent_dim: width of every hidden and output layer.
        layernorm: accepted for interface compatibility; not read by this class.
        without_glb: when true, skip the global feature and feed the point
            embedding straight into the final MLP.
        only_with_glb: when true, return only the pooled global embedding.
    """

    @staticmethod
    def _make_mlp(in_dim, out_dim):
        """Build a Linear -> ReLU -> Linear stack mapping ``in_dim`` to ``out_dim``."""
        return nn.Sequential(
            nn.Linear(in_dim, out_dim), nn.ReLU(),
            nn.Linear(out_dim, out_dim),
        )

    def __init__(self, data_rep, input_feats, latent_dim, layernorm=True, without_glb=False, only_with_glb=False):
        super().__init__()
        self.data_rep = data_rep
        self.input_feats = input_feats
        self.latent_dim = latent_dim
        self.without_glb = without_glb
        self.only_with_glb = only_with_glb

        # Sub-networks are created in the same order as before so seeded
        # initialisation and state_dict keys are unchanged.
        self.pts_feats_encoding_net = self._make_mlp(input_feats, self.latent_dim)
        self.glb_feats_encoding_net = self._make_mlp(self.latent_dim, self.latent_dim)
        # The fusion MLP sees the point embedding alone, or point + global concatenated.
        fused_in_dim = self.latent_dim if self.without_glb else self.latent_dim * 2
        self.pts_glb_feats_encoding_net = self._make_mlp(fused_in_dim, self.latent_dim)

    def forward(self, x):
        """Encode ``x`` (bsz x nf x nnb x input_feats).

        Returns:
            ``nf x bsz x latent_dim`` when ``only_with_glb`` is set, otherwise
            ``nf x (bsz * nnb) x latent_dim``.
        """
        bsz, nf, nnb = x.size()[:3]
        point_emb = self.pts_feats_encoding_net(x)  # bsz x nf x nnb x latent_dim

        if self.only_with_glb:
            pooled, _ = torch.max(point_emb, dim=2, keepdim=True)
            global_emb = self.glb_feats_encoding_net(pooled).squeeze(-2)  # bsz x nf x latent_dim
            return global_emb.permute(1, 0, 2).contiguous()

        if not self.without_glb:
            pooled, _ = torch.max(point_emb, dim=2, keepdim=True)
            global_emb = self.glb_feats_encoding_net(pooled)  # bsz x nf x 1 x latent_dim
            # Broadcast the global feature to every base point before fusing.
            tiled_global = global_emb.repeat(1, 1, nnb, 1)
            point_emb = torch.cat([point_emb, tiled_global], dim=-1)

        fused = self.pts_glb_feats_encoding_net(point_emb)  # bsz x nf x nnb x latent_dim
        fused = fused.permute(1, 0, 2, 3)  # nf x bsz x nnb x latent_dim
        n_frames = fused.size(0)
        return fused.contiguous().view(n_frames, bsz * nnb, -1).contiguous()
class OutputProcessObjBaseRawV5(nn.Module):
    """Decode per-base-point latents into relative offsets (``dec_rel``).

    Two decoding modes exist:

    * bundle mode (``out_objbase_v5_bundle_out=True``): one MLP call per base
      point emits ``nn_keypoints * 3`` values at once;
    * per-pair mode (default): a single linear layer emits 3 values for every
      (joint, base point) pair.

    Args:
        data_rep: stored on the instance; not read by this class.
        latent_dim: width of the incoming latent vectors.
        not_cond_base: per-pair mode only -- decode from the latent alone
            instead of latent + base point + base normal + noisy offset.
        out_objbase_v5_bundle_out: select bundle mode.
        v5_out_not_cond_base: bundle mode only -- decode from the latent alone
            instead of latent + base point + base normal.
        nn_keypoints: number of keypoints emitted per base point in bundle
            mode; must equal the joint count of the input in that mode.
    """

    def __init__(self, data_rep, latent_dim, not_cond_base=False, out_objbase_v5_bundle_out=False, v5_out_not_cond_base=False, nn_keypoints=21):
        super().__init__()
        self.data_rep = data_rep
        self.latent_dim = latent_dim
        self.not_cond_base = not_cond_base
        self.v5_out_not_cond_base = v5_out_not_cond_base
        self.nn_keypoints = nn_keypoints
        self.out_objbase_v5_bundle_out = out_objbase_v5_bundle_out

        # Per-pair decoder input: latent only, or latent + base point (3)
        # + base normal (3) + noisy relative offset (3).
        if self.not_cond_base:
            self.rel_dec_cond_dim = self.latent_dim
        else:
            self.rel_dec_cond_dim = self.latent_dim + 3 + 3 + 3

        if self.out_objbase_v5_bundle_out:
            # Bundle decoder input: latent only, or latent + base point + base normal.
            bundle_in_dim = self.latent_dim if self.v5_out_not_cond_base else self.latent_dim + 3 + 3
            self.rel_dec_blk = nn.Sequential(
                nn.Linear(bundle_in_dim, self.latent_dim // 2), nn.ReLU(),
                nn.Linear(self.latent_dim // 2, self.nn_keypoints * 3),
            )
        else:
            self.rel_dec_blk = nn.Sequential(
                nn.Linear(self.rel_dec_cond_dim, 3),
            )

        # The distance head is not called in forward(); it is still built so
        # the module's parameter set and state_dict keys stay unchanged.
        self.dist_dec_cond_dim = self.latent_dim + 3 + 3
        self.dist_dec_blk = nn.Linear(self.dist_dec_cond_dim, 1 * self.nn_keypoints)

    def forward(self, output, x):
        """Decode latents into relative offsets.

        Args:
            output: latents reshapeable to ``nframes x bsz x nnb x latent_dim``.
            x: dict providing ``'pert_rel_base_pts_to_rhand_joints'``
                (``bsz x nframes x nnj x nnb x 3``) and, whenever base
                conditioning is enabled, ``'normed_base_pts'`` and
                ``'base_normals'`` (each ``bsz x nnb x 3``).

        Returns:
            ``{'dec_rel': tensor}`` with shape ``bsz x nframes x nnj x nnb x 3``.
        """
        noisy_rel = x['pert_rel_base_pts_to_rhand_joints']
        bsz, nframes, nnj, nnb = noisy_rel.shape[:4]

        # nframes x (bsz * nnb) x d  ->  bsz x nframes x nnb x d
        output = output.reshape(nframes, bsz, nnb, -1).permute(1, 0, 2, 3)

        if self.out_objbase_v5_bundle_out:
            if self.v5_out_not_cond_base:
                output_exp = output
            else:
                base_pts_exp = x['normed_base_pts'].unsqueeze(1).expand(-1, nframes, -1, -1)
                base_normals_exp = x['base_normals'].unsqueeze(1).expand(-1, nframes, -1, -1)
                output_exp = torch.cat([output, base_pts_exp, base_normals_exp], dim=-1)
            dec_rel = self.rel_dec_blk(output_exp)  # bsz x nframes x nnb x (nn_keypoints * 3)
            dec_rel = dec_rel.reshape(bsz, nframes, nnb, nnj, 3).permute(0, 1, 3, 2, 4).contiguous()
        else:
            # Share each base-point latent across all joints.
            output = output.unsqueeze(2).expand(-1, -1, nnj, -1, -1)
            if self.not_cond_base:
                # rel_dec_blk was built for latent-only input in this mode, so
                # no conditioning channels may be appended.
                output_exp = output
            else:
                base_pts_exp = x['normed_base_pts'][:, None, None].expand(-1, nframes, nnj, -1, -1)
                base_normals_exp = x['base_normals'][:, None, None].expand(-1, nframes, nnj, -1, -1)
                output_exp = torch.cat(
                    [output, base_pts_exp, base_normals_exp, noisy_rel], dim=-1
                )
            dec_rel = self.rel_dec_blk(output_exp)  # bsz x nframes x nnj x nnb x 3
            dec_rel = dec_rel.reshape(bsz, nframes, nnj, nnb, 3).contiguous()

        return {'dec_rel': dec_rel}