import numpy as np import torch import torch.nn as nn # import torch.nn.functional as F import clip # from model.rotation2xyz import Rotation2xyz import utils.model_utils as model_utils # from model.PointNet2 import PointnetPP # from model.DGCNN import PrimitiveNet # from utils.anchor_utils import masking_load_driver, anchor_load_driver ### load driver; masking load driver # # import os ### MDM 10 ### class MDMV10(nn.Module): def __init__(self, modeltype, njoints, nfeats, num_actions, translation, pose_rep, glob, glob_rot, latent_dim=256, ff_size=1024, num_layers=8, num_heads=4, dropout=0.1, ablation=None, activation="gelu", legacy=False, data_rep='rot6d', dataset='amass', clip_dim=512, arch='trans_enc', emb_trans_dec=False, clip_version=None, **kargs): super().__init__() self.legacy = legacy self.modeltype = modeltype self.njoints = njoints self.nfeats = nfeats self.num_actions = num_actions self.data_rep = data_rep self.dataset = dataset self.pose_rep = pose_rep self.glob = glob self.glob_rot = glob_rot self.translation = translation self.latent_dim = latent_dim self.ff_size = ff_size self.num_layers = num_layers self.num_heads = num_heads self.dropout = dropout self.ablation = ablation self.activation = activation self.clip_dim = clip_dim self.action_emb = kargs.get('action_emb', None) self.input_feats = self.njoints * self.nfeats self.normalize_output = kargs.get('normalize_encoder_output', False) self.cond_mode = kargs.get('cond_mode', 'no_cond') self.cond_mask_prob = kargs.get('cond_mask_prob', 0.) # ### GET args ### self.args = kargs.get('args', None) ### GET the diff. suit ### self.diff_jts = self.args.diff_jts self.diff_basejtsrel = self.args.diff_basejtsrel self.diff_basejtse = self.args.diff_basejtse self.diff_realbasejtsrel = self.args.diff_realbasejtsrel self.diff_realbasejtsrel_to_joints = self.args.diff_realbasejtsrel_to_joints ### GET the diff. suit ### self.arch = arch ## ==== gru_emb_dim ==== ## # gru emb dim # self.gru_emb_dim = self.latent_dim if self.arch == 'gru' else 0 # joint_sequence_input_process; joint_sequence_pos_encoder; joint_sequence_seqTransEncoder; joint_sequence_seqTransDecoder; joint_sequence_embed_timestep; joint_sequence_output_process # ###### ======= Construct joint sequence encoder, communicator, and decoder ======== ####### self.use_anchors = self.args.use_anchors # if self.use_anchors: # use anchors # anchor_load_driver, masking_load_driver # # # anchor_load_driver, masking_load_driver # # inpath = "/home/xueyi/sim/CPF/assets" # contact potential field; assets # ## # fvi, aw, _, _ = anchor_load_driver(inpath) # self.face_vertex_index = torch.from_numpy(fvi).long() # self.anchor_weight = torch.from_numpy(aw).float() # anchor_path = os.path.join("/home/xueyi/sim/CPF/assets", "anchor") # palm_path = os.path.join("/home/xueyi/sim/CPF/assets", "hand_palm_full.txt") # hand_region_assignment, hand_palm_vertex_mask = masking_load_driver(anchor_path, palm_path) # # self.hand_palm_vertex_mask for hand palm mask # # self.hand_palm_vertex_mask = torch.from_numpy(hand_palm_vertex_mask).bool() ## the mask for hand palm to get hand anchors # # self.nn_anchors = int(self.hand_palm_vertex_mask.float().sum()) #### number of anchors here ### # self.joints_feats_in_dim = 21 * 3 # joints feats in dim # self.nn_keypoints = 21 # if self.args.use_anchors: # # self.nn_keypoints = self.nn_anchors # nn_anchors # # self.nn_keypoints = 32 # nn_anchors # self.joints_feats_in_dim = self.nn_keypoints * 3 self.data_rep = "xyz" if self.diff_jts: ## Input process for joints ## self.joint_sequence_input_process = InputProcess(self.data_rep, self.joints_feats_in_dim, self.latent_dim) # # InputProcessObjBase(self, data_rep, input_feats, latent_dim) # self.joint_sequence_input_process = InputProcessObjBase(self.data_rep, 3, self.latent_dim) # self.input_process = InputProcess(self.data_rep, self.input_feats+self.gru_emb_dim, self.latent_dim) self.joint_sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout) self.emb_trans_dec = emb_trans_dec # if self.arch == 'trans_enc': print("TRANS_ENC init") ## transformer encoder layer ## UNet seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, nhead=self.num_heads, dim_feedforward=self.ff_size, dropout=self.dropout, activation=self.activation) ## Joint sequence transformer encoder layer ## # sequence encoder # self.joint_sequence_seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer, num_layers=self.num_layers) ### logvar for the encoding laeyer and # # logvar_seqTransEncoder_e, logvar_seqTransEncoder, joint_sequence_logvar_seqTransEncoder # # seqTransEncoderLayer_logvar = nn.TransformerEncoderLayer(d_model=self.latent_dim, nhead=self.num_heads, dim_feedforward=self.ff_size, dropout=self.dropout, activation=self.activation) ## Joint sequence transformer encoder layer ## # sequence encoder # self.joint_sequence_logvar_seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer_logvar, num_layers=self.num_layers) # elif self.arch == 'trans_dec': # print("TRANS_DEC init") # seqTransDecoderLayer = nn.TransformerDecoderLayer(d_model=self.latent_dim, # nhead=self.num_heads, # num_heads # dim_feedforward=self.ff_size, # dropout=self.dropout, # activation=activation) # self.joint_sequence_seqTransDecoder = nn.TransformerDecoder(seqTransDecoderLayer, # num_layers=self.num_layers) # elif self.arch == 'gru': # print("GRU init") # self.joint_sequence_gru = nn.GRU(self.latent_dim, self.latent_dim, num_layers=self.num_layers, batch_first=True) # else: # raise ValueError('Please choose correct architecture [trans_enc, trans_dec, gru]') ### joint sequence embed timestep ## ## timestep self.joint_sequence_embed_timestep = TimestepEmbedder(self.latent_dim, self.joint_sequence_pos_encoder) # self.joint_sequence_output_process = OutputProcess(self.data_rep, self.latent_dim) # (self, data_rep, input_feats, latent_dim, njoints, nfeats): #### ====== joint sequence denoising block ====== #### ## seqTransEncoder ## self.joint_sequence_denoising_embed_timestep = TimestepEmbedder(self.latent_dim, self.joint_sequence_pos_encoder) self.joint_sequence_denoising_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout) if self.args.use_ours_transformer_enc: self.joint_sequence_denoising_seqTransEncoder = model_utils.TransformerEncoder( hidden_size=self.latent_dim, fc_size=self.ff_size, num_heads=self.num_heads, layer_norm=True, num_layers=self.num_layers, dropout_rate=0.2, re_zero=True, memory_efficient=False, ) else: seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, nhead=self.num_heads, dim_feedforward=self.ff_size, dropout=self.dropout, activation=self.activation) self.joint_sequence_denoising_seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer, num_layers=self.num_layers) # seq_len x bsz x dim if self.args.const_noise: # 1) max pool latents over the sequence # 2) transform the pooled latnets via the linear layer self.glb_denoising_latents_trans_layer = nn.Sequential( nn.Linear(self.latent_dim, self.latent_dim * 2), nn.ReLU(), nn.Linear(self.latent_dim * 2, self.latent_dim) ) # refinement for predicted joints # --> not in the paradigm of generation # # seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, # nhead=self.num_heads, # dim_feedforward=self.ff_size, # dropout=self.dropout, # activation=self.activation) # self.joint_sequence_denoising_seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer, # num_layers=self.num_layers) #### ====== joint sequence denoiisng block ====== #### ### Output process ### output proces for joint sequence ### # output proces --> datarep, joints feats in dim, latent dim ## ###### joints_feats_in_dim ###### self.joint_sequence_output_process = OutputProcess(self.data_rep, self.joints_feats_in_dim, self.latent_dim, 21, 3) # OutputProcessCond # self.joint_sequence_output_process = OutputProcessCond(self.data_rep, self.joints_feats_in_dim, self.latent_dim, 21, 3) ###### ======= Construct joint sequence encoder, communicator, and decoder ======== ####### # real_basejtsrel_to_joints_embed_timestep, real_basejtsrel_to_joints_sequence_pos_denoising_encoder, real_basejtsrel_to_joints_denoising_seqTransEncoder, real_basejtsrel_to_joints_output_process if self.diff_realbasejtsrel_to_joints: # feature for each joint point? --> for the denoising purpose # # real_basejtsrel_input_process, real_basejtsrel_sequence_pos_encoder, real_basejtsrel_seqTransEncoder, real_basejtsrel_embed_timestep, real_basejtsrel_sequence_pos_denoising_encoder, real_basejtsrel_denoising_seqTransEncoder layernorm = True self.rel_input_feats = 21 * (3 + 3 + 3) # base pts, normals, the relative positions if self.args.use_abs_jts_for_encoding_obj_base: self.rel_input_feats = 21 * (3) # layernorm = False self.real_basejtsrel_to_joints_input_process = InputProcessObjBaseV2(self.data_rep, self.rel_input_feats, self.latent_dim, layernorm=layernorm) # self.real_basejtsrel_to_joints_input_process = InputProcessObjBase(self.data_rep, self.rel_input_feats, self.latent_dim, layernorm=layernorm) # elif self.args.use else: if self.args.use_objbase_v2: self.rel_input_feats = 3 + 3 + (21 * 3) self.real_basejtsrel_to_joints_input_process = InputProcessObjBaseV2(self.data_rep, self.rel_input_feats, self.latent_dim, layernorm=layernorm) elif self.args.use_objbase_v3: self.rel_input_feats = 3 + 3 + (21 * 3) self.real_basejtsrel_to_joints_input_process = InputProcessObjBaseV3(self.data_rep, self.rel_input_feats, self.latent_dim, layernorm=layernorm) else: self.real_basejtsrel_to_joints_input_process = InputProcessObjBase(self.data_rep, self.rel_input_feats, self.latent_dim, layernorm=layernorm) if self.args.use_abs_jts_for_encoding: # use_abs_jts_for_encoding, real_basejtsrel_to_joints_input_process self.real_basejtsrel_to_joints_input_process = InputProcess(self.data_rep, self.joints_feats_in_dim, self.latent_dim) self.real_basejtsrel_to_joints_sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout) ### Encoding layer ### # InputProcessObjBaseV2 real_basejtsrel_to_joints_seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, # latent dim # nn_heads # ff_size # dropout # nhead=self.num_heads, dim_feedforward=self.ff_size, dropout=self.dropout, activation=self.activation) self.real_basejtsrel_to_joints_seqTransEncoder = nn.TransformerEncoder(real_basejtsrel_to_joints_seqTransEncoderLayer, # basejtsrel_seqTrans num_layers=self.num_layers) ### timesteps embedding layer ### self.real_basejtsrel_to_joints_embed_timestep = TimestepEmbedder(self.latent_dim, self.real_basejtsrel_to_joints_sequence_pos_encoder) self.real_basejtsrel_to_joints_sequence_pos_denoising_encoder = PositionalEncoding(self.latent_dim, self.dropout) if self.args.use_ours_transformer_enc: # our transformer encoder # self.real_basejtsrel_to_joints_denoising_seqTransEncoder = model_utils.TransformerEncoder( hidden_size=self.latent_dim, fc_size=self.ff_size, num_heads=self.num_heads, layer_norm=True, num_layers=self.num_layers, dropout_rate=0.2, re_zero=True, memory_efficient=False, ) else: real_basejtsrel_to_joints_denoising_seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, nhead=self.num_heads, dim_feedforward=self.ff_size, dropout=self.dropout, activation=self.activation) self.real_basejtsrel_to_joints_denoising_seqTransEncoder = nn.TransformerEncoder(real_basejtsrel_to_joints_denoising_seqTransEncoderLayer, num_layers=self.num_layers) # self.real_basejtsrel_output_process = OutputProcessObjBaseRawV2(self.data_rep, self.latent_dim, not_cond_base=self.args.not_cond_base) self.real_basejtsrel_to_joints_output_process = OutputProcess(self.data_rep, self.joints_feats_in_dim, self.latent_dim, 21, 3) # OutputProcessCond if self.diff_realbasejtsrel: # real_basejtsrel_input_process, real_basejtsrel_sequence_pos_encoder, real_basejtsrel_seqTransEncoder, real_basejtsrel_embed_timestep, real_basejtsrel_sequence_pos_denoising_encoder, real_basejtsrel_denoising_seqTransEncoder # self.rel_input_feats = 21 * (3 + 3 + 3) # base pts, normals, the relative positions # self.real_basejtsrel_input_process = InputProcessObjBase(self.data_rep, self.rel_input_feats, self.latent_dim) self.rel_input_feats = self.nn_keypoints * (3 + 3 + 3) layernorm = True if self.args.use_objbase_v2: self.rel_input_feats = 3 + 3 + (self.nn_keypoints * 3) self.real_basejtsrel_input_process = InputProcessObjBaseV2(self.data_rep, self.rel_input_feats, self.latent_dim, layernorm=layernorm, glb_feats_trans=True) elif self.args.use_objbase_v4: # use_objbase_out_v4 self.rel_input_feats = (self.args.nn_base_pts * (3 + 3 + 3)) # current joint positions # how to keep the dimension self.real_basejtsrel_input_process = InputProcessObjBaseV4(self.data_rep, self.rel_input_feats, self.latent_dim, layernorm=layernorm) elif self.args.use_objbase_v5: # use_objbase_v5, if self.args.v5_in_not_base: self.rel_input_feats = (self.nn_keypoints * 3) elif self.args.v5_in_not_base_pos: self.rel_input_feats = 3 + (self.nn_keypoints * 3) else: self.rel_input_feats = 3 + 3 + (self.nn_keypoints * 3) self.real_basejtsrel_input_process = InputProcessObjBaseV5(self.data_rep, self.rel_input_feats, self.latent_dim, layernorm=layernorm, without_glb=self.args.v5_in_without_glb) elif self.args.use_objbase_v6: # real_basejtsrel_input_process self.rel_input_feats = 3 + 3 + (self.nn_keypoints * 3) + 3 self.real_basejtsrel_input_process = InputProcessObjBaseV6(self.data_rep, self.rel_input_feats, self.latent_dim, layernorm=layernorm) elif self.args.use_objbase_v7: # InputProcessObjBaseV7 self.rel_input_feats = 3 + 3 + (self.nn_keypoints * 3) self.real_basejtsrel_input_process = InputProcessObjBaseV7(self.data_rep, self.rel_input_feats, self.latent_dim, layernorm=layernorm) else: self.real_basejtsrel_input_process = InputProcessObjBase(self.data_rep, self.rel_input_feats, self.latent_dim, layernorm=layernorm) self.real_basejtsrel_sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout) ### Encoding layer ### real_basejtsrel_seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, # latent dim # nn_heads # ff_size # dropout # # dropout # # dropout nhead=self.num_heads, dim_feedforward=self.ff_size, dropout=self.dropout, activation=self.activation) self.real_basejtsrel_seqTransEncoder = nn.TransformerEncoder(real_basejtsrel_seqTransEncoderLayer, # basejtsrel_seqTrans num_layers=self.num_layers) ### timesteps embedding layer ### self.real_basejtsrel_embed_timestep = TimestepEmbedder(self.latent_dim, self.real_basejtsrel_sequence_pos_encoder) self.real_basejtsrel_sequence_pos_denoising_encoder = PositionalEncoding(self.latent_dim, self.dropout) if self.args.use_ours_transformer_enc: # our transformer encoder # self.real_basejtsrel_denoising_seqTransEncoder = model_utils.TransformerEncoder( hidden_size=self.latent_dim, fc_size=self.ff_size, num_heads=self.num_heads, layer_norm=True, num_layers=self.num_layers, dropout_rate=0.2, re_zero=True, memory_efficient=False, ) else: real_basejtsrel_denoising_seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, nhead=self.num_heads, dim_feedforward=self.ff_size, dropout=self.dropout, activation=self.activation) self.real_basejtsrel_denoising_seqTransEncoder = nn.TransformerEncoder(real_basejtsrel_denoising_seqTransEncoderLayer, num_layers=self.num_layers) print(f"not_cond_base: {self.args.not_cond_base}, latent_dim: {self.latent_dim}") if self.args.use_jts_pert_realbasejtsrel: print(f"use_jts_pert_realbasejtsrel!!!!!!") self.real_basejtsrel_output_process = OutputProcess(self.data_rep, self.joints_feats_in_dim, self.latent_dim, self.nn_keypoints, 3) else: if self.args.use_objbase_out_v3: self.real_basejtsrel_output_process = OutputProcessObjBaseRawV3(self.data_rep, self.latent_dim, not_cond_base=self.args.not_cond_base) elif self.args.use_objbase_out_v4: self.real_basejtsrel_output_process = OutputProcessObjBaseRawV4(self.data_rep, self.latent_dim, not_cond_base=self.args.not_cond_base) elif self.args.use_objbase_out_v5: # use_objbase_v5, use_objbase_out_v5 self.real_basejtsrel_output_process = OutputProcessObjBaseRawV5(self.data_rep, self.latent_dim, not_cond_base=self.args.not_cond_base, out_objbase_v5_bundle_out=self.args.out_objbase_v5_bundle_out, v5_out_not_cond_base=self.args.v5_out_not_cond_base, nn_keypoints=self.nn_keypoints) else: self.real_basejtsrel_output_process = OutputProcessObjBaseRawV2(self.data_rep, self.latent_dim, not_cond_base=self.args.not_cond_base) # OutputProcessCond if self.diff_basejtsrel: # treate them as textures of signals to model # # base pts -> dec on base pts features --> # latent space denoising and feature decoding --> a little bit concern about the feature decoding process # # TODO: add base_pts and base_normals to the base points -rel-to- rhand joints encoding process # self.rel_input_feats = self.nn_keypoints * (3 + 3 + 3) # relative positions from base pts to rhand joints ## # self.avg_joints_sequence_input_process, self.avg_joint_sequence_output_process # self.avg_joints_sequence_input_process = InputProcess(self.data_rep, self.joints_feats_in_dim, self.latent_dim) if self.args.with_glb_info: # InputProcessWithGlbInfo self.joints_offset_input_process = InputProcessWithGlbInfo(self.data_rep, self.joints_feats_in_dim, self.latent_dim) else: self.joints_offset_input_process = InputProcess(self.data_rep, self.joints_feats_in_dim, self.latent_dim) if self.args.not_cond_base: self.rel_input_feats = self.nn_keypoints * ( 3) # self.input_process = InputProcessObjBase(self.data_rep, self.rel_input_feats+self.gru_emb_dim, self.latent_dim) self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout) self.emb_trans_dec = emb_trans_dec ### Encoding layer ### seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, nhead=self.num_heads, dim_feedforward=self.ff_size, dropout=self.dropout, activation=self.activation) self.seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer, num_layers=self.num_layers) ### Encoding layer ### # logvar_seqTransEncoder_e, logvar_seqTransEncoder # logvar_seqTranEncoder seqTransEncoderLayer_logvar = nn.TransformerEncoderLayer(d_model=self.latent_dim, nhead=self.num_heads, dim_feedforward=self.ff_size, dropout=self.dropout, activation=self.activation) self.logvar_seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer_logvar, num_layers=self.num_layers) ### timesteps embedding layer ### self.embed_timestep = TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder) # basejtsrel_denoising_embed_timestep, basejtsrel_denoising_seqTransEncoder, output_process # # baseptsrel # self.basejtsrel_denoising_embed_timestep = TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder) self.sequence_pos_denoising_encoder = PositionalEncoding(self.latent_dim, self.dropout) if self.args.use_ours_transformer_enc: # our transformer encoder # self.basejtsrel_denoising_seqTransEncoder = model_utils.TransformerEncoder( hidden_size=self.latent_dim, fc_size=self.ff_size, num_heads=self.num_heads, layer_norm=True, num_layers=self.num_layers, dropout_rate=0.2, re_zero=True, memory_efficient=False, ) else: seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, nhead=self.num_heads, dim_feedforward=self.ff_size, dropout=self.dropout, activation=self.activation) self.basejtsrel_denoising_seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer, num_layers=self.num_layers) # seq_len x bsz x dim if self.args.const_noise: # add to attention network # # 1) max pool latents over the sequence # 2) transform the pooled latnets via the linear layer self.basejtsrel_glb_denoising_latents_trans_layer = nn.Sequential( nn.Linear(self.latent_dim, self.latent_dim * 2), nn.ReLU(), nn.Linear(self.latent_dim * 2, self.latent_dim) ) ###### joints_feats_in_dim ###### # a linear transformation net with weights and bias set to zero # self.avg_joint_sequence_output_process = OutputProcess(self.data_rep, self.joints_feats_in_dim, self.latent_dim, self.nn_keypoints, 3) # output avgjts sequence # OutputProcessCond self.joint_offset_output_process = OutputProcess(self.data_rep, self.joints_feats_in_dim, self.latent_dim, self.nn_keypoints, 3) if self.args.use_dec_rel_v2: self.output_process = OutputProcessObjBaseRawV2(self.data_rep, self.latent_dim, not_cond_base=self.args.not_cond_base) else: # OutputProcessObjBaseRaw ## output process for basejtsrel # self.output_process = OutputProcessObjBaseRaw(self.data_rep, self.latent_dim, not_cond_base=self.args.not_cond_base) ##### ==== input process, communications, output process for rel, dists ==== ##### if self.diff_basejtse: ### input process obj base ### # construct input_process_e # # self.input_feats_e = 21 * (3 + 3 + 3 + 1 + 1) self.input_feats_e = self.nn_keypoints * (3 + 3 + 1 + 1) self.input_process_e = InputProcessObjBase(self.data_rep, self.input_feats_e+self.gru_emb_dim, self.latent_dim) self.sequence_pos_encoder_e = PositionalEncoding(self.latent_dim, self.dropout) self.emb_trans_dec = emb_trans_dec # # single layer transformers # ## predict relative position for each base point? # existing model # if self.arch == 'trans_enc': print("TRANS_ENC init") seqTransEncoderLayer_e = nn.TransformerEncoderLayer(d_model=self.latent_dim, nhead=self.num_heads, dim_feedforward=self.ff_size, dropout=self.dropout, activation=self.activation) self.seqTransEncoder_e = nn.TransformerEncoder(seqTransEncoderLayer_e, num_layers=self.num_layers) print("TRANS_ENC init") # logvar_seqTransEncoder_e, seqTransEncoderLayer_e_logvar = nn.TransformerEncoderLayer(d_model=self.latent_dim, nhead=self.num_heads, dim_feedforward=self.ff_size, dropout=self.dropout, activation=self.activation) self.logvar_seqTransEncoder_e = nn.TransformerEncoder(seqTransEncoderLayer_e_logvar, num_layers=self.num_layers) # # elif self.arch == 'trans_dec': # print("TRANS_DEC init") # seqTransDecoderLayer_e = nn.TransformerDecoderLayer(d_model=self.latent_dim, # nhead=self.num_heads, # dim_feedforward=self.ff_size, # dropout=self.dropout, # activation=activation) # self.seqTransDecoder_e = nn.TransformerDecoder(seqTransDecoderLayer_e, # num_layers=self.num_layers) # elif self.arch == 'gru': ## arch ## # print("GRU init") # self.gru_e = nn.GRU(self.latent_dim, self.latent_dim, num_layers=self.num_layers, batch_first=True) # else: # raise ValueError('Please choose correct architecture [trans_enc, trans_dec, gru]') # tiemstep # # timestep embedding e # Embed timestep e # self.embed_timestep_e = TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder_e) self.sequence_pos_denoising_encoder_e = PositionalEncoding(self.latent_dim, self.dropout) # basejtsrel_denoising_embed_timestep, basejtsrel_denoising_seqTransEncoder, output_process # self.basejtse_denoising_embed_timestep = TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder_e) if self.args.use_ours_transformer_enc: # our transformer encoder # self.basejtse_denoising_seqTransEncoder = model_utils.TransformerEncoder( hidden_size=self.latent_dim, fc_size=self.ff_size, num_heads=self.num_heads, layer_norm=True, num_layers=self.num_layers, dropout_rate=0.2, re_zero=True, memory_efficient=False, ) ### basejtse_denoising_seqTransEncoder ### else: basejtse_seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, nhead=self.num_heads, dim_feedforward=self.ff_size, dropout=self.dropout, activation=self.activation) self.basejtse_denoising_seqTransEncoder = nn.TransformerEncoder(basejtse_seqTransEncoderLayer, num_layers=self.num_layers) # basejtse_seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, # nhead=self.num_heads, # dim_feedforward=self.ff_size, # dropout=self.dropout, # activation=self.activation) # self.basejtse_denoising_seqTransEncoder = nn.TransformerEncoder(basejtse_seqTransEncoderLayer, # num_layers=self.num_layers) # seq_len x bsz x dim if self.args.const_noise: # 1) max pool latents over the sequence # 2) transform the pooled latnets via the linear layer self.basejtse_denoising_seqTransEncoder = nn.Sequential( nn.Linear(self.latent_dim, self.latent_dim * 2), nn.ReLU(), nn.Linear(self.latent_dim * 2, self.latent_dim) ) # self.output_process_e = OutputProcessObjBaseV3(self.data_rep, self.latent_dim) self.output_process_e = OutputProcessObjBaseERaw(self.data_rep, self.latent_dim) # self.rot2xyz = Rotation2xyz(device='cpu', dataset=self.dataset) def set_enc_to_eval(self): # jts: joint_sequence_input_process, joint_sequence_pos_encoder, joint_sequence_seqTransEncoder, joint_sequence_logvar_seqTransEncoder # basejtsrel: input_process, sequence_pos_encoder, seqTransEncoder, logvar_seqTransEncoder, # basejtse: input_process_e, sequence_pos_encoder_e, seqTransEncoder_e, logvar_seqTransEncoder_e if self.diff_jts: self.joint_sequence_input_process.eval() self.joint_sequence_pos_encoder.eval() self.joint_sequence_seqTransEncoder.eval() self.joint_sequence_logvar_seqTransEncoder.eval() if self.diff_basejtse: self.input_process_e.eval() self.sequence_pos_encoder_e.eval() self.seqTransEncoder_e.eval() self.logvar_seqTransEncoder_e.eval() if self.diff_basejtsrel: self.input_process.eval() self.sequence_pos_encoder.eval() self.seqTransEncoder.eval() # seqTransEncoder, logvar_seqTransEncoder self.logvar_seqTransEncoder.eval() def set_bn_to_eval(self): if self.args.use_objbase_v6: # real_basejtsrel_input_process try: self.real_basejtsrel_input_process.pnpp_conv_net.set_bn_no_training() except: pass def parameters_wo_clip(self): return [p for name, p in self.named_parameters() if not name.startswith('clip_model.')] def load_and_freeze_clip(self, clip_version): clip_model, clip_preprocess = clip.load(clip_version, device='cpu', jit=False) # Must set jit=False for training clip.model.convert_weights( # encode # ours float clip_model) # Actually this line is unnecessary since clip by default already on float16 ### ours # Freeze CLIP weights clip_model.eval() for p in clip_model.parameters(): p.requires_grad = False return clip_model def mask_cond(self, cond, force_mask=False): bs, d = cond.shape if force_mask: return torch.zeros_like(cond) elif self.training and self.cond_mask_prob > 0.: mask = torch.bernoulli(torch.ones(bs, device=cond.device) * self.cond_mask_prob).view(bs, 1) # 1-> use null_cond, 0-> use real cond return cond * (1. - mask) else: return cond def encode_text(self, raw_text): # raw_text - list (batch_size length) of strings with input text prompts # device = next(self.parameters()).device max_text_len = 20 if self.dataset in ['humanml', 'kit', 'motion_ours'] else None # Specific hardcoding for humanml dataset if max_text_len is not None: default_context_length = 77 context_length = max_text_len + 2 # start_token + 20 + end_token assert context_length < default_context_length texts = clip.tokenize(raw_text, context_length=context_length, truncate=True).to(device) # [bs, context_length] # if n_tokens > context_length -> will truncate # print('texts', texts.shape) zero_pad = torch.zeros([texts.shape[0], default_context_length-context_length], dtype=texts.dtype, device=texts.device) texts = torch.cat([texts, zero_pad], dim=1) # print('texts after pad', texts.shape, texts) else: ## texts = clip.tokenize(raw_text, truncate=True).to(device) # [bs, context_length] # if n_tokens > 77 -> will truncate return self.clip_model.encode_text(texts).float() def dec_jts_only_fr_latents(self, latents_feats): joint_seq_output = self.joint_sequence_output_process(latents_feats) # [bs, njoints, nfeats, nframes] # bsz x ws x nnj x nnfeats # --> joints_seq_outputs # joint_seq_output = joint_seq_output.permute(0, 3, 1, 2).contiguous() ## joints seq outputs ## diff_jts_dict = { "joint_seq_output": joint_seq_output, "joints_seq_latents": latents_feats, } return diff_jts_dict def dec_basejtsrel_only_fr_latents(self, latent_feats, x): # basejtsrel_seq_latents_pred_feats avg_jts_seq_latents = latent_feats[0:1, ...] other_basejtsrel_seq_latents = latent_feats[1:, ...] avg_jts_outputs = self.avg_joint_sequence_output_process(avg_jts_seq_latents) # bsz x njoints x nfeats x 1 avg_jts_outputs = avg_jts_outputs.squeeze(-1) # bsz x nnjoints x 1 --> avg joints here # # basejtsrel_seq_latents = self.basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents)[1:] basejtsrel_output = self.output_process(other_basejtsrel_seq_latents, x) basejtsrel_dec_out = { 'avg_jts_outputs': avg_jts_outputs, 'basejtsrel_output': basejtsrel_output['dec_rel'], } return basejtsrel_dec_out # def dec_latents_to_joints_with_t(self, joints_seq_latents, timesteps): def dec_latents_to_joints_with_t(self, input_latent_feats, x, timesteps): # # logvar_seqTransEncoder_e, logvar_seqTransEncoder, joint_sequence_logvar_seqTransEncoder # # # joints_seq_latents: seq x bs x d --> perturbed joitns_seq_latents \in [-1, 1] ## # def dec_latents_to_joints_with_t(self, joints_seq_latents, timesteps): ## positional encoding for denoising ## # rt_dict = { # 'joint_seq_output': joint_seq_output, # 'rel_base_pts_outputs': rel_base_pts_outputs, # } rt_dict = {} if self.diff_jts: ####### input latent feats ####### joints_seq_latents = input_latent_feats["joints_seq_latents"] if not self.args.without_dec_pos_emb: joints_seq_latents = self.joint_sequence_denoising_pos_encoder(joints_seq_latents) # ### GET joints seq time embeddings ### ### embed time stamps ### # joints_seq_time_emb = self.joint_sequence_embed_timestep(timesteps) # joints_seq_latents = torch.cat( # [joints_seq_time_emb, joints_seq_latents], dim=0 # ) # joints_seq_latents = self.joint_sequence_denoising_seqTransEncoder(joints_seq_latents)[1:] # seq x bs x d if self.args.deep_fuse_timeemb: ## denoising process ### ## GET joints seq time embeddings ### ### embed time stamps ### joints_seq_time_emb = self.joint_sequence_embed_timestep(timesteps) joints_seq_time_emb = joints_seq_time_emb.repeat(joints_seq_latents.size(0), 1, 1).contiguous() joints_seq_latents = joints_seq_latents + joints_seq_time_emb if self.args.use_ours_transformer_enc: joints_seq_latents = self.joint_sequence_denoising_seqTransEncoder(joints_seq_latents.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) joints_seq_latents = joints_seq_latents.permute(1, 0, 2) else: joints_seq_latents = self.joint_sequence_denoising_seqTransEncoder(joints_seq_latents) else: ## denoising process ### ## GET joints seq time embeddings ### ### embed time stamps ### joints_seq_time_emb = self.joint_sequence_embed_timestep(timesteps) joints_seq_latents = torch.cat( [joints_seq_time_emb, joints_seq_latents], dim=0 ) if self.args.use_ours_transformer_enc: ## mdm ours ## joints_seq_latents = self.joint_sequence_denoising_seqTransEncoder(joints_seq_latents.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) joints_seq_latents = joints_seq_latents.permute(1, 0, 2)[1:] else: joints_seq_latents = self.joint_sequence_denoising_seqTransEncoder(joints_seq_latents)[1:] # joints_seq_latents: seq_len x bsz x latent_dim # if self.args.const_noise: seq_len = joints_seq_latents.size(0) # if self.args.const_noise: joints_seq_latents, _ = torch.max(joints_seq_latents, dim=0, keepdim=True) joints_seq_latents = self.glb_denoising_latents_trans_layer(joints_seq_latents) # seq_len x bsz x latent_dim joints_seq_latents = joints_seq_latents.repeat(seq_len, 1, 1).contiguous() ### sequence latents ### if self.args.train_enc: # trian enc for seq latents ### joints_seq_latents = input_latent_feats["joints_seq_latents_enc"] # bsz x ws x nnj x 3 # joint_seq_output = self.joint_sequence_output_process(joints_seq_latents) # [bs, njoints, nfeats, nframes] # bsz x ws x nnj x nnfeats # --> joints_seq_outputs # joint_seq_output = joint_seq_output.permute(0, 3, 1, 2).contiguous() diff_jts_dict = { "joint_seq_output": joint_seq_output, "joints_seq_latents": joints_seq_latents, } else: diff_jts_dict = {} if self.diff_basejtsrel: rel_base_pts_outputs = input_latent_feats["rel_base_pts_outputs"] if rel_base_pts_outputs.size(0) == 1 and self.args.single_frame_noise: rel_base_pts_outputs = rel_base_pts_outputs.repeat(self.args.window_size + 1, 1, 1) if not self.args.without_dec_pos_emb: # without avg_jts_inputs = rel_base_pts_outputs[0:1, ...] other_rel_base_pts_outputs = rel_base_pts_outputs[1: , ...] other_rel_base_pts_outputs = self.sequence_pos_denoising_encoder(other_rel_base_pts_outputs) rel_base_pts_outputs = torch.cat( [avg_jts_inputs, other_rel_base_pts_outputs], dim=0 ) if self.args.deep_fuse_timeemb: ## denoising process ### ## GET joints seq time embeddings ### ### embed time stamps ### basejtsrel_time_emb = self.basejtsrel_denoising_embed_timestep(timesteps) basejtsrel_time_emb = basejtsrel_time_emb.repeat(rel_base_pts_outputs.size(0), 1, 1).contiguous() basejtsrel_seq_latents = rel_base_pts_outputs + basejtsrel_time_emb if self.args.use_ours_transformer_enc: basejtsrel_seq_latents = self.basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) basejtsrel_seq_latents = basejtsrel_seq_latents.permute(1, 0, 2) else: basejtsrel_seq_latents = self.basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents) else: ## denoising process ### ## GET joints seq time embeddings ### ### embed time stamps ### basejtsrel_time_emb = self.basejtsrel_denoising_embed_timestep(timesteps) basejtsrel_seq_latents = torch.cat( [basejtsrel_time_emb, rel_base_pts_outputs], dim=0 ) if self.args.use_ours_transformer_enc: ## mdm ours ## basejtsrel_seq_latents = self.basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) basejtsrel_seq_latents = basejtsrel_seq_latents.permute(1, 0, 2)[1:] else: basejtsrel_seq_latents = self.basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents)[1:] ### sequence latents ### if self.args.train_enc: # trian enc for seq latents ### basejtsrel_seq_latents = input_latent_feats["rel_base_pts_outputs_enc"] if basejtsrel_seq_latents.size(0) == 1 and self.args.single_frame_noise: basejtsrel_seq_latents = basejtsrel_seq_latents.repeat(self.args.window_size + 1, 1, 1) basejtsrel_seq_latents_pred_feats = basejtsrel_seq_latents elif self.args.pred_diff_noise: basejtsrel_seq_latents_pred_feats = input_latent_feats["rel_base_pts_outputs"] - basejtsrel_seq_latents else: basejtsrel_seq_latents_pred_feats = basejtsrel_seq_latents # rel_base_pts_outputs = self.sequence_pos_denoising_encoder(rel_base_pts_outputs) ### GET joints seq output ### # basejtsrel_denoising_embed_timestep, basejtsrel_denoising_seqTransEncoder, output_process # # basejtsrel_time_emb = self.basejtsrel_denoising_embed_timestep(timesteps) # basejtsrel_seq_latents = torch.cat( # [basejtsrel_time_emb, rel_base_pts_outputs], dim=0 # ) # basejtsrel_seq_latents_pred_feats avg_jts_seq_latents = basejtsrel_seq_latents_pred_feats[0:1, ...] other_basejtsrel_seq_latents = basejtsrel_seq_latents_pred_feats[1:, ...] avg_jts_outputs = self.avg_joint_sequence_output_process(avg_jts_seq_latents) # bsz x njoints x nfeats x 1 avg_jts_outputs = avg_jts_outputs.squeeze(-1) # bsz x nnjoints x 1 --> avg joints here # # basejtsrel_seq_latents = self.basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents)[1:] basejtsrel_output = self.output_process(other_basejtsrel_seq_latents, x) #### # avg_jts_seq_latents = basejtsrel_seq_latents[0:1, ...] # other_basejtsrel_seq_latents = basejtsrel_seq_latents[1:, ...] # avg_jts_outputs = self.avg_joint_sequence_output_process(avg_jts_seq_latents) # bsz x njoints x nfeats x 1 # avg_jts_outputs = avg_jts_outputs.squeeze(-1) # bsz x nnjoints x 1 --> avg joints here # # # basejtsrel_seq_latents = self.basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents)[1:] # basejtsrel_output = self.output_process(other_basejtsrel_seq_latents, x) diff_basejtsrel_dict = { "basejtsrel_output": basejtsrel_output['dec_rel'], "basejtsrel_seq_latents": basejtsrel_seq_latents, "avg_jts_outputs": avg_jts_outputs, } else: diff_basejtsrel_dict = {} if self.diff_basejtse: # e_disp_rel_to_base_along_normals = input_latent_feats['e_disp_rel_to_base_along_normals'] # e_disp_rel_to_baes_vt_normals = input_latent_feats['e_disp_rel_to_baes_vt_normals'] base_jts_e_feats = input_latent_feats['base_jts_e_feats'] # seq x bs x d --> e feats if not self.args.without_dec_pos_emb: # rel_base_pts_outputs = self.sequence_pos_denoising_encoder(rel_base_pts_outputs) base_jts_e_feats = self.sequence_pos_denoising_encoder_e(base_jts_e_feats) if self.args.deep_fuse_timeemb: ## denoising process ### ## GET joints seq time embeddings ### ### embed time stamps ### base_jts_e_time_emb = self.embed_timestep_e(timesteps) base_jts_e_time_emb = base_jts_e_time_emb.repeat(base_jts_e_feats.size(0), 1, 1).contiguous() base_jts_e_feats = base_jts_e_feats + base_jts_e_time_emb if self.args.use_ours_transformer_enc: ## transformer encoder ## base_jts_e_feats = self.basejtse_denoising_seqTransEncoder(base_jts_e_feats.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) base_jts_e_feats = base_jts_e_feats.permute(1, 0, 2) else: base_jts_e_feats = self.basejtse_denoising_seqTransEncoder(base_jts_e_feats) else: #### Decode e_along_normals and e_vt_normals #### base_jts_e_time_emb = self.embed_timestep_e(timesteps) base_jts_e_feats = torch.cat( [base_jts_e_time_emb, base_jts_e_feats], dim=0 ) if self.args.use_ours_transformer_enc: ## transformer encoder ## base_jts_e_feats = self.basejtse_denoising_seqTransEncoder(base_jts_e_feats.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) base_jts_e_feats = base_jts_e_feats.permute(1, 0, 2)[1:] else: base_jts_e_feats = self.basejtse_denoising_seqTransEncoder(base_jts_e_feats)[1:] ### sequence latents ### if self.args.train_enc: # trian enc for seq latents ### base_jts_e_feats = input_latent_feats["base_jts_e_feats_enc"] # base_jts_e_feats = self.sequence_pos_denoising_encoder_e(base_jts_e_feats) #### Decode e_along_normals and e_vt_normals #### ### bae jts e embeddings feats ### # base_jts_e_time_emb = self.embed_timestep_e(timesteps) # base_jts_e_feats = torch.cat( # [base_jts_e_time_emb, base_jts_e_feats], dim=0 # ) # base_jts_e_feats = self.basejtse_denoising_seqTransEncoder(base_jts_e_feats)[1:] base_jts_e_output = self.output_process_e(base_jts_e_feats, x) # dec_e_along_normals: bsz x (ws - 1) x nnb x nnj # dec_e_along_normals = base_jts_e_output['dec_e_along_normals'] dec_e_vt_normals = base_jts_e_output['dec_e_vt_normals'] dec_e_along_normals = dec_e_along_normals.contiguous().permute(0, 1, 3, 2).contiguous() dec_e_vt_normals = dec_e_vt_normals.contiguous().permute(0, 1, 3, 2).contiguous() #### Decode e_along_normals and e_vt_normals #### diff_basejtse_dict = { 'dec_e_along_normals': dec_e_along_normals, 'dec_e_vt_normals': dec_e_vt_normals, 'base_jts_e_feats': base_jts_e_feats, } else: diff_basejtse_dict = {} rt_dict = {} rt_dict.update(diff_jts_dict) rt_dict.update(diff_basejtsrel_dict) rt_dict.update(diff_basejtse_dict) ### rt_dict --> rt_dict of joints, rel ### return rt_dict # return joint_seq_output, joints_seq_latents def reparameterization(self, val_mean, val_var): val_noise = torch.randn_like(val_mean) val_sampled = val_mean + val_noise * val_var ### sample the value if self.args.rnd_noise: val_sampled = val_noise return val_sampled def decode_realbasejtsrel_from_objbasefeats(self, objbasefeats, input_data): real_dec_basejtsrel = self.real_basejtsrel_output_process( objbasefeats, input_data ) # real_dec_basejtsrel -> decoded realtive positions # real_dec_basejtsrel = real_dec_basejtsrel['dec_rel'] real_dec_basejtsrel = real_dec_basejtsrel.permute(0, 1, 3, 2, 4).contiguous() return real_dec_basejtsrel def denoising_realbasejtsrel_objbasefeats(self, pert_obj_base_pts_feats, timesteps): if self.args.deep_fuse_timeemb: ## denoising process ### ## GET joints seq time embeddings ### ### embed time stamps ### real_basejtsrel_time_emb = self.real_basejtsrel_embed_timestep(timesteps) real_basejtsrel_time_emb = real_basejtsrel_time_emb.repeat(pert_obj_base_pts_feats.size(0), 1, 1).contiguous() real_basejtsrel_seq_latents = pert_obj_base_pts_feats + real_basejtsrel_time_emb if self.args.use_ours_transformer_enc: real_basejtsrel_seq_latents = self.real_basejtsrel_denoising_seqTransEncoder(real_basejtsrel_seq_latents.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) real_basejtsrel_seq_latents = real_basejtsrel_seq_latents.permute(1, 0, 2) else: # seq des real_basejtsrel_seq_latents = self.real_basejtsrel_denoising_seqTransEncoder(real_basejtsrel_seq_latents) else: ## denoising process ### ## GET joints seq time embeddings ### ### embed time stamps ### real_basejtsrel_time_emb = self.real_basejtsrel_embed_timestep(timesteps) real_basejtsrel_seq_latents = torch.cat( [real_basejtsrel_time_emb, pert_obj_base_pts_feats], dim=0 ) if self.args.use_ours_transformer_enc: ## mdm ours ## real_basejtsrel_seq_latents = self.real_basejtsrel_denoising_seqTransEncoder(real_basejtsrel_seq_latents.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) real_basejtsrel_seq_latents = real_basejtsrel_seq_latents.permute(1, 0, 2)[1:] else: real_basejtsrel_seq_latents = self.real_basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents)[1:] # bsz, nframes, nnb, nnj, 3 --> # # real_dec_basejtsrel = self.real_basejtsrel_output_process( # real_basejtsrel_seq_latents, x # ) return real_basejtsrel_seq_latents def forward(self, x, timesteps): """ x: [batch_size, njoints, nfeats, max_frames], denoted x_t in the paper timesteps: [batch_size] (int) """ # joint_sequence_input_process; joint_sequence_pos_encoder; joint_sequence_seqTransEncoder; joint_sequence_seqTransDecoder; joint_sequence_embed_timestep; joint_sequence_output_process # # bsz, nframes, nnj = x['pert_rhand_joints'].shape[:3] # pert_rhand_joints = x['pert_rhand_joints'] # bsz x ws x nnj x 3 # bsz x nnj x 3 x ws bsz, nframes, nnj = x['rhand_joints'].shape[:3] pert_rhand_joints = x['rhand_joints'] # bsz x ws x nnj x 3 # bsz x nnj x 3 x ws base_pts = x['base_pts'] ### bsz x nnb x 3 ### base_normals = x['base_normals'] ### bsz x nnb x 3 ### --> base normals ### # base_normals # ## rt_dict = {} ## # # logvar_seqTransEncoder_e, logvar_seqTransEncoder, joint_sequence_logvar_seqTransEncoder # # if self.diff_basejtse: ### Embed physicss quantities ### # e_disp_rel_to_base_along_normals: bsz x (ws - 1) x nnj x nnb # # e_disp_rel_to_baes_vt_normals: bsz x (ws - 1) x nnj x nnb # # e_disp_rel_to_base_along_normals = x['e_disp_rel_to_base_along_normals'] # e_disp_rel_to_baes_vt_normals = x['e_disp_rel_to_baes_vt_normals'] e_disp_rel_to_base_along_normals = x['pert_e_disp_rel_to_base_along_normals'] e_disp_rel_to_baes_vt_normals = x['pert_e_disp_rel_to_base_vt_normals'] nnb = base_pts.size(1) disp_ws = e_disp_rel_to_base_along_normals.size(1) ### --> base normals ### base_pts_disp_exp = base_pts.unsqueeze(1).unsqueeze(1).repeat(1, disp_ws, nnj, 1, 1).contiguous() base_normals_disp_exp = base_normals.unsqueeze(1).unsqueeze(1).repeat(1, disp_ws, nnj, 1, 1).contiguous() # bsz x (ws - 1) x nnj x nnb x (3 + 3 + 1 + 1) base_pts_normals_e_in_feats = torch.cat( # along normals; # vt normals # [base_pts_disp_exp, base_normals_disp_exp, e_disp_rel_to_base_along_normals.unsqueeze(-1), e_disp_rel_to_baes_vt_normals.unsqueeze(-1)], dim=-1 ) base_pts_normals_e_in_feats = base_pts_normals_e_in_feats.permute(0, 1, 3, 2, 4).contiguous() # bsz x (ws - 1) x nnb x (nnj x (xxx feats_dim)) base_pts_normals_e_in_feats = base_pts_normals_e_in_feats.view(bsz, disp_ws, nnb, -1).contiguous() ## input process ## base_jts_e_feats = self.input_process_e(base_pts_normals_e_in_feats) base_jts_e_feats = self.sequence_pos_encoder_e(base_jts_e_feats) ## seq transformation for e ## # base_jts_e_feats_mean = self.seqTransEncoder_e(base_jts_e_feats) ## mean, mdm_ours ## # print(f"base_jts_e_feats: {base_jts_e_feats.size()}") ### Embed physicss quantities ### #### base_jts_e_feats, base_jts_e_feats_mean #### # ## us basejtsefeats for denoising directly ## base_jts_e_feats = base_jts_e_feats_mean if not self.args.without_dec_pos_emb: ## use positional encoding ## # rel_base_pts_outputs = self.sequence_pos_denoising_encoder(rel_base_pts_outputs) base_jts_e_feats = self.sequence_pos_denoising_encoder_e(base_jts_e_feats) if self.args.deep_fuse_timeemb: ## denoising process ### ## GET joints seq time embeddings ### ### embed time stamps ### base_jts_e_time_emb = self.embed_timestep_e(timesteps) base_jts_e_time_emb = base_jts_e_time_emb.repeat(base_jts_e_feats.size(0), 1, 1).contiguous() base_jts_e_feats = base_jts_e_feats + base_jts_e_time_emb if self.args.use_ours_transformer_enc: ## transformer encoder ## base_jts_e_feats = self.basejtse_denoising_seqTransEncoder(base_jts_e_feats.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) base_jts_e_feats = base_jts_e_feats.permute(1, 0, 2) else: base_jts_e_feats = self.basejtse_denoising_seqTransEncoder(base_jts_e_feats) else: #### Decode e_along_normals and e_vt_normals #### base_jts_e_time_emb = self.embed_timestep_e(timesteps) base_jts_e_feats = torch.cat( [base_jts_e_time_emb, base_jts_e_feats], dim=0 ) if self.args.use_ours_transformer_enc: ## transformer encoder ## base_jts_e_feats = self.basejtse_denoising_seqTransEncoder(base_jts_e_feats.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) base_jts_e_feats = base_jts_e_feats.permute(1, 0, 2)[1:] else: base_jts_e_feats = self.basejtse_denoising_seqTransEncoder(base_jts_e_feats)[1:] # ### sequence latents ### # if self.args.train_enc: # trian enc for seq latents ### # base_jts_e_feats = input_latent_feats["base_jts_e_feats_enc"] # base_jts_e_feats = self.sequence_pos_denoising_encoder_e(base_jts_e_feats) #### Decode e_along_normals and e_vt_normals #### ### bae jts e embeddings feats ### # base_jts_e_time_emb = self.embed_timestep_e(timesteps) # base_jts_e_feats = torch.cat( # [base_jts_e_time_emb, base_jts_e_feats], dim=0 # ) # base_jts_e_feats = self.basejtse_denoising_seqTransEncoder(base_jts_e_feats)[1:] ##### output_process_e -> output energies ##### base_jts_e_output = self.output_process_e(base_jts_e_feats, x) # dec_e_along_normals: bsz x (ws - 1) x nnb x nnj # dec_e_along_normals = base_jts_e_output['dec_e_along_normals'] dec_e_vt_normals = base_jts_e_output['dec_e_vt_normals'] dec_e_along_normals = dec_e_along_normals.contiguous().permute(0, 1, 3, 2).contiguous() # bsz x (ws - 1) x nnj x nnb dec_e_vt_normals = dec_e_vt_normals.contiguous().permute(0, 1, 3, 2).contiguous() # bsz x (ws - 1) x nnj x nnb #### Decode e_along_normals and e_vt_normals #### diff_basejtse_dict = { 'dec_e_along_normals': dec_e_along_normals, 'dec_e_vt_normals': dec_e_vt_normals, 'base_jts_e_feats': base_jts_e_feats, } # rt_dict['base_jts_e_feats'] = base_jts_e_feats # rt_dict['base_jts_e_feats_mean'] = base_jts_e_feats_mean # rt_dict['base_jts_e_feats_logvar'] = base_jts_e_feats_logvar # log_var # else: diff_basejtse_dict = {} if self.diff_jts: # base_pts_normal ### InputProcess ### pert_rhand_joints_trans = pert_rhand_joints.permute(0, 2, 3, 1).contiguous() # bsz x nnj x 3 x ws # rhand_joints_feats = self.joint_sequence_input_process(pert_rhand_joints_trans) # [seqlen, bs, d] ### InputProcessObjBase ### # rhand_joints_feats = self.joint_sequence_input_process(pert_rhand_joints) ### === Encode input joint sequences === ### # bs, njoints, nfeats, nframes = x.shape # rhand_joints_emb = self.joint_sequence_embed_timestep(timesteps) # [1, bs, d] # if self.arch == 'trans_enc': xseq = rhand_joints_feats # [seqlen+1, bs, d] xseq = self.joint_sequence_pos_encoder(xseq) # [seqlen+1, bs, d] joint_seq_output_mean = self.joint_sequence_seqTransEncoder(xseq) # [1:] # , src_key_padding_mask=~maskseq) # [seqlen, bs, d] ### calculate logvar, mean, and feats ### joint_seq_output_logvar = self.joint_sequence_logvar_seqTransEncoder(xseq) # # logvar_seqTransEncoder_e, logvar_seqTransEncoder, joint_sequence_logvar_seqTransEncoder # # joint_seq_output_var = torch.exp(joint_seq_output_logvar) # seq x bs x d --> encodeing and decoding ## base_jts_e_feats: seqlen x bs x d --> val latents ## joint_seq_output = self.reparameterization(joint_seq_output_mean, joint_seq_output_var) rt_dict['joint_seq_output'] = joint_seq_output # rt_dict['joint_seq_output'] = joint_seq_output_mean rt_dict['joint_seq_output_mean'] = joint_seq_output_mean rt_dict['joint_seq_output_logvar'] = joint_seq_output_logvar if self.args.diff_realbasejtsrel_to_joints: # nframes x nnbase x nnjts x (base pts + base normals + 3) 2) point feature for each point; point feature for; condition on the noisy input for the denoised information # real_basejtsrel_to_joints_input_process, real_basejtsrel_to_joints_sequence_pos_encoder, real_basejtsrel_to_joints_seqTransEncoder # real_basejtsrel_to_joints_embed_timestep, real_basejtsrel_to_joints_sequence_pos_denoising_encoder, real_basejtsrel_to_joints_denoising_seqTransEncoder, real_basejtsrel_to_joints_output_process bsz, nf, nnj, nnb = x['pert_rel_base_pts_to_joints_for_jts_pred'].size()[:4] normed_base_pts = x['normed_base_pts'] base_normals = x['base_normals'] pert_rel_base_pts_to_rhand_joints = x['pert_rel_base_pts_to_joints_for_jts_pred'] # bsz x nf x nnj x nnb x 3 ## use_abs_jts_pos --> obj jts pos for the encodingj if self.args.use_abs_jts_pos: ## bsz x nf x nnj x nnb x 3 ## ---> abs jts pos ## pert_rel_base_pts_to_rhand_joints = pert_rel_base_pts_to_rhand_joints + normed_base_pts.unsqueeze(1).unsqueeze(1) # use_abs_jts_for_encoding, real_basejtsrel_to_joints_input_process if self.args.use_abs_jts_for_encoding: if not self.args.use_abs_jts_pos: pert_rel_base_pts_to_rhand_joints = pert_rel_base_pts_to_rhand_joints + normed_base_pts.unsqueeze(1).unsqueeze(1) # pert_rel_base_pts_to_rhand_joints: bsz x nf x nnj x nnb x 3 abs_jts = pert_rel_base_pts_to_rhand_joints[..., 0, :] abs_jts = abs_jts.permute(0, 2, 3, 1).contiguous() obj_base_encoded_feats = self.real_basejtsrel_to_joints_input_process(abs_jts) elif self.args.use_abs_jts_for_encoding_obj_base: if not self.args.use_abs_jts_pos: pert_rel_base_pts_to_rhand_joints = pert_rel_base_pts_to_rhand_joints + normed_base_pts.unsqueeze(1).unsqueeze(1) # pert_rel_base_pts_to_rhand_joints: bsz x nf x nnj x nnb x 3 pert_rel_base_pts_to_rhand_joints = pert_rel_base_pts_to_rhand_joints[:, :, :, 0:1, :] # obj_base_in_feats = torch.cat( # bsz x nf x nnj x nnb x (3 + 3 + 3) # [normed_base_pts.unsqueeze(1).unsqueeze(1).repeat(1, nf, nnj, 1, 1)[:, :, :, 0:1, :], base_normals.unsqueeze(1).unsqueeze(1).repeat(1, nf, nnj, 1, 1)[:, :, :, 0:1, :], pert_rel_base_pts_to_rhand_joints], dim=-1 # ) obj_base_in_feats = pert_rel_base_pts_to_rhand_joints # --> tnrasform the input feature dim to 21 * 3 here for encoding # obj_base_in_feats = obj_base_in_feats.transpose(-2, -3).contiguous().view(bsz, nf, 1, -1).contiguous() # obj_base_encoded_feats = self.real_basejtsrel_to_joints_input_process(obj_base_in_feats) # nf x bsz x feat_dim # else: if self.args.use_objbase_v2: # bsz x nf x nnj x nnb x 3 pert_rel_base_pts_to_rhand_joints_exp = pert_rel_base_pts_to_rhand_joints.transpose(-2, -3).contiguous().view(bsz, nf, nnb, -1).contiguous() obj_base_in_feats = torch.cat( [normed_base_pts.unsqueeze(1).repeat(1, nf, 1, 1), base_normals.unsqueeze(1).repeat(1, nf, 1, 1), pert_rel_base_pts_to_rhand_joints_exp], dim=-1 ) else: obj_base_in_feats = torch.cat( # bsz x nf x nnj x nnb x (3 + 3 + 3) [normed_base_pts.unsqueeze(1).unsqueeze(1).repeat(1, nf, nnj, 1, 1), base_normals.unsqueeze(1).unsqueeze(1).repeat(1, nf, nnj, 1, 1), pert_rel_base_pts_to_rhand_joints], dim=-1 ) # print(f"obj_base_in_feats: {obj_base_in_feats.size()}, bsz: {bsz}, nf: {nf}, nnj: {nnj}, nnb: {nnb}") obj_base_in_feats = obj_base_in_feats.transpose(-2, -3).contiguous().view(bsz, nf, nnb, -1).contiguous() # obj_base_encoded_feats = self.real_basejtsrel_to_joints_input_process(obj_base_in_feats) # nf x bsz x feat_dim # # ### real_basejtsrel_to_joints_input_process --> real_basejtsrel_to_joints_input_process --> for the joints and input process ### # obj_base_encoded_feats # obj_base_encoded_feats obj_base_pts_feats_pos_embedding = self.real_basejtsrel_to_joints_sequence_pos_encoder(obj_base_encoded_feats) obj_base_pts_feats = self.real_basejtsrel_to_joints_seqTransEncoder(obj_base_pts_feats_pos_embedding) if self.args.use_sigmoid: obj_base_pts_feats = (torch.sigmoid(obj_base_pts_feats) - 0.5) * 2. if self.args.deep_fuse_timeemb: ## denoising process ### ## GET joints seq time embeddings ### ### embed time stamps ### real_basejtsrel_time_emb = self.real_basejtsrel_to_joints_embed_timestep(timesteps) real_basejtsrel_time_emb = real_basejtsrel_time_emb.repeat(obj_base_pts_feats.size(0), 1, 1).contiguous() real_basejtsrel_seq_latents = obj_base_pts_feats + real_basejtsrel_time_emb if self.args.use_ours_transformer_enc: real_basejtsrel_seq_latents = self.real_basejtsrel_to_joints_denoising_seqTransEncoder(real_basejtsrel_seq_latents.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) real_basejtsrel_seq_latents = real_basejtsrel_seq_latents.permute(1, 0, 2) else: # seq des real_basejtsrel_seq_latents = self.real_basejtsrel_to_joints_denoising_seqTransEncoder(real_basejtsrel_seq_latents) else: ## denoising process ### ## GET joints seq time embeddings ### ### embed time stamps ### real_basejtsrel_time_emb = self.real_basejtsrel_to_joints_embed_timestep(timesteps) real_basejtsrel_seq_latents = torch.cat( [real_basejtsrel_time_emb, obj_base_pts_feats], dim=0 ) if self.args.use_ours_transformer_enc: ## mdm ours ## real_basejtsrel_seq_latents = self.real_basejtsrel_to_joints_denoising_seqTransEncoder(real_basejtsrel_seq_latents.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) real_basejtsrel_seq_latents = real_basejtsrel_seq_latents.permute(1, 0, 2)[1:] else: real_basejtsrel_seq_latents = self.real_basejtsrel_to_joints_denoising_seqTransEncoder(basejtsrel_seq_latents)[1:] joints_offset_output = self.real_basejtsrel_to_joints_output_process(real_basejtsrel_seq_latents) joints_offset_output = joints_offset_output.permute(0, 3, 1, 2) if self.args.diff_basejtsrel: diff_basejtsrel_to_joints_dict = { 'joints_offset_output_from_rel': joints_offset_output } else: diff_basejtsrel_to_joints_dict = { 'joints_offset_output': joints_offset_output } else: diff_basejtsrel_to_joints_dict = {} if self.diff_realbasejtsrel: # real_dec_basejtsrel # real_basejtsrel_input_process, real_basejtsrel_sequence_pos_encoder, real_basejtsrel_seqTransEncoder, real_basejtsrel_embed_timestep, real_basejtsrel_sequence_pos_denoising_encoder, real_basejtsrel_denoising_seqTransEncoder bsz, nf, nnj, nnb = x['pert_rel_base_pts_to_rhand_joints'].size()[:4] normed_base_pts = x['normed_base_pts'] base_normals = x['base_normals'] pert_rel_base_pts_to_rhand_joints = x['pert_rel_base_pts_to_rhand_joints'] # bsz x nf x nnj x nnb x 3 if self.args.use_objbase_v2: # bsz x nf x nnj x nnb x 3 pert_rel_base_pts_to_rhand_joints_exp = pert_rel_base_pts_to_rhand_joints.transpose(-2, -3).contiguous().view(bsz, nf, nnb, -1).contiguous() obj_base_in_feats = torch.cat( [normed_base_pts.unsqueeze(1).repeat(1, nf, 1, 1), base_normals.unsqueeze(1).repeat(1, nf, 1, 1), pert_rel_base_pts_to_rhand_joints_exp], dim=-1 ) elif self.args.use_objbase_v4: # use_objbase_v4: # use_objbase_out_v4 exp_normed_base_pts = normed_base_pts.unsqueeze(1).unsqueeze(2).repeat(1, nf, nnj, 1, 1).contiguous() exp_base_normals = base_normals.unsqueeze(1).unsqueeze(2).repeat(1, nf, nnj, 1, 1).contiguous() obj_base_in_feats = torch.cat( [pert_rel_base_pts_to_rhand_joints, exp_normed_base_pts, exp_base_normals], dim=-1 # bsz x nf x nnj x nnb x (3 + 3 + 3) # -> exp_base_normals ) obj_base_in_feats = obj_base_in_feats.view(bsz, nf, nnj, -1).contiguous() elif self.args.use_objbase_v5: # use_objbase_v5, use_objbase_out_v5 pert_rel_base_pts_to_rhand_joints_exp = pert_rel_base_pts_to_rhand_joints.transpose(-2, -3).contiguous().view(bsz, nf, nnb, -1).contiguous() if self.args.v5_in_not_base: obj_base_in_feats = torch.cat( [ pert_rel_base_pts_to_rhand_joints_exp], dim=-1 ) elif self.args.v5_in_not_base_pos: obj_base_in_feats = torch.cat( [base_normals.unsqueeze(1).repeat(1, nf, 1, 1), pert_rel_base_pts_to_rhand_joints_exp], dim=-1 ) else: obj_base_in_feats = torch.cat( [normed_base_pts.unsqueeze(1).repeat(1, nf, 1, 1), base_normals.unsqueeze(1).repeat(1, nf, 1, 1), pert_rel_base_pts_to_rhand_joints_exp], dim=-1 ) elif self.args.use_objbase_v6 or self.args.use_objbase_v7: pert_rel_base_pts_to_rhand_joints_exp = pert_rel_base_pts_to_rhand_joints.transpose(-2, -3).contiguous().view(bsz, nf, nnb, -1).contiguous() obj_base_in_feats = torch.cat( [normed_base_pts.unsqueeze(1).repeat(1, nf, 1, 1), base_normals.unsqueeze(1).repeat(1, nf, 1, 1), pert_rel_base_pts_to_rhand_joints_exp], dim=-1 ) else: obj_base_in_feats = torch.cat( # bsz x nf x nnj x nnb x (3 + 3 + 3) [normed_base_pts.unsqueeze(1).unsqueeze(1).repeat(1, nf, nnj, 1, 1), base_normals.unsqueeze(1).unsqueeze(1).repeat(1, nf, nnj, 1, 1), pert_rel_base_pts_to_rhand_joints], dim=-1 ) # print(f"obj_base_in_feats: {obj_base_in_feats.size()}, bsz: {bsz}, nf: {nf}, nnj: {nnj}, nnb: {nnb}") obj_base_in_feats = obj_base_in_feats.transpose(-2, -3).contiguous().view(bsz, nf, nnb, -1).contiguous() # # obj_base_in_feats = torch.cat( # bsz x nf x nnj x nnb x (3 + 3 + 3) # [normed_base_pts.unsqueeze(1).unsqueeze(1).repeat(1, nf, nnj, 1, 1), base_normals.unsqueeze(1).unsqueeze(1).repeat(1, nf, nnj, 1, 1), pert_rel_base_pts_to_rhand_joints], dim=-1 # ) # print(f"obj_base_in_feats: {obj_base_in_feats.size()}, bsz: {bsz}, nf: {nf}, nnj: {nnj}, nnb: {nnb}") # obj_base_in_feats = obj_base_in_feats.transpose(-2, -3).contiguous().view(bsz, nf, nnb, -1).contiguous() # if self.args.use_objbase_v6: normed_base_pts_exp = normed_base_pts.unsqueeze(1).repeat(1, nf, 1, 1) # and repeat for the base pts # obj_base_encoded_feats = self.real_basejtsrel_input_process(obj_base_in_feats, normed_base_pts_exp) else: obj_base_encoded_feats = self.real_basejtsrel_input_process(obj_base_in_feats) # nf x bsz x feat_dim # nf x bsz x nnbasepts x feats_dim # # obj_base_encoded_feats obj_base_pts_feats_pos_embedding = self.real_basejtsrel_sequence_pos_encoder(obj_base_encoded_feats) obj_base_pts_feats = self.real_basejtsrel_seqTransEncoder(obj_base_pts_feats_pos_embedding) if self.args.use_sigmoid: obj_base_pts_feats = (torch.sigmoid(obj_base_pts_feats) - 0.5) * 2. if self.args.train_enc: # basejtsrel_seq # bsz, nframes, nnb, nnj, 3 --> # real_dec_basejtsrel = self.real_basejtsrel_output_process( obj_base_pts_feats, x ) real_dec_basejtsrel = real_dec_basejtsrel['dec_rel'] if self.args.use_objbase_out_v3: real_dec_basejtsrel = real_dec_basejtsrel else: real_dec_basejtsrel = real_dec_basejtsrel.permute(0, 1, 3, 2, 4).contiguous() # bsz x nf x nnj x nnb x 3 else: if self.args.deep_fuse_timeemb: ## denoising process ### ## GET joints seq time embeddings ### ### embed time stamps ### # print(f"timesteps: {timesteps.size()}, obj_base_pts_feats: {obj_base_pts_feats.size()}") if self.args.use_objbase_v5: cur_timesteps = timesteps.unsqueeze(1).repeat(1, nnb).view(-1) else: cur_timesteps = timesteps real_basejtsrel_time_emb = self.real_basejtsrel_embed_timestep(cur_timesteps) real_basejtsrel_time_emb = real_basejtsrel_time_emb.repeat(obj_base_pts_feats.size(0), 1, 1).contiguous() real_basejtsrel_seq_latents = obj_base_pts_feats + real_basejtsrel_time_emb if self.args.use_ours_transformer_enc: real_basejtsrel_seq_latents = self.real_basejtsrel_denoising_seqTransEncoder(real_basejtsrel_seq_latents.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) real_basejtsrel_seq_latents = real_basejtsrel_seq_latents.permute(1, 0, 2) else: # seq des real_basejtsrel_seq_latents = self.real_basejtsrel_denoising_seqTransEncoder(real_basejtsrel_seq_latents) else: ## denoising process ### ## GET joints seq time embeddings ### ### embed time stamps ### real_basejtsrel_time_emb = self.real_basejtsrel_embed_timestep(timesteps) real_basejtsrel_seq_latents = torch.cat( [real_basejtsrel_time_emb, obj_base_pts_feats], dim=0 ) if self.args.use_ours_transformer_enc: ## mdm ours ## real_basejtsrel_seq_latents = self.real_basejtsrel_denoising_seqTransEncoder(real_basejtsrel_seq_latents.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) real_basejtsrel_seq_latents = real_basejtsrel_seq_latents.permute(1, 0, 2)[1:] else: real_basejtsrel_seq_latents = self.real_basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents)[1:] # basejtsrel_seq # bsz, nframes, nnb, nnj, 3 --> # if self.args.use_jts_pert_realbasejtsrel: joints_offset_output = self.real_basejtsrel_output_process(real_basejtsrel_seq_latents) joints_offset_output = joints_offset_output.permute(0, 3, 1, 2) # bsz x nf x nnj x 3 real_dec_basejtsrel = joints_offset_output.unsqueeze(-2).repeat(1, 1, 1, nnb, 1) # real_dec_basejtsrel = joints_offset_output else: real_dec_basejtsrel = self.real_basejtsrel_output_process( real_basejtsrel_seq_latents, x ) # real_dec_basejtsrel -> decoded realtive positions # real_dec_basejtsrel = real_dec_basejtsrel['dec_rel'] if self.args.use_objbase_out_v3 or self.args.use_objbase_out_v4 or self.args.use_objbase_out_v5: real_dec_basejtsrel = real_dec_basejtsrel else: real_dec_basejtsrel = real_dec_basejtsrel.permute(0, 1, 3, 2, 4).contiguous() # bsz x nf x nnj x nnb x 3 diff_realbasejtsrel_out_dict = { 'real_dec_basejtsrel': real_dec_basejtsrel, 'obj_base_pts_feats': obj_base_pts_feats, } else: diff_realbasejtsrel_out_dict = {} # relative joints encoder; obj pos encoder; obj pos encoder; # penetrations, depth, --- how to use depth for guidance -> and also penetrations # object penetrations # if self.diff_basejtsrel: # joints_offset_sequence --> x['pert_joints_offset_sequence'] joints_offset_sequence = x['pert_joints_offset_sequence'] # bsz x nf x nnj x 3 joints_offset_sequence = joints_offset_sequence.permute(0, 2, 3, 1).contiguous() joints_offset_feats = self.joints_offset_input_process(joints_offset_sequence) # nf x bsz x dim # rel_base_pts_feats = self.input_process(basejtsrel_enc_in_feats) # sequence_pos_encoder rel_base_pts_feats_pos_embedding = self.sequence_pos_encoder(joints_offset_feats) # print(f"joints_offset_feats: {joints_offset_feats.size()}, rel_base_pts_feats_pos_embedding: {rel_base_pts_feats_pos_embedding.size()}") # outputs rel base jts encoded latents ## # seqTransEncoder, logvar_seqTransEncoder # rel_base_pts_outputs_mean = self.basejtsrel_denoising_seqTransEncoder(rel_base_pts_feats_pos_embedding) # ### calculate logvar, mean, and feats ### # rel_base_pts_outputs_logvar = self.joint_sequence_logvar_seqTransEncoder(rel_base_pts_outputs) if self.args.not_diff_avgjts: # not use diff avgjts ## rel_base_pts_feats = rel_base_pts_feats_pos_embedding # seqTransEncoder, logvar_seqTransEncoder # rel_base_pts_outputs_mean = self.seqTransEncoder(rel_base_pts_feats) # print(f"rel_base_pts_outputs_mean 1: {rel_base_pts_outputs_mean.size()}") if not self.args.without_dec_pos_emb: # without dec pos embedding # avg_jts_inputs = rel_base_pts_outputs_mean[0:1, ...] other_rel_base_pts_outputs = rel_base_pts_outputs_mean # other_rel_base_pts_outputs = self.sequence_pos_denoising_encoder(other_rel_base_pts_outputs) rel_base_pts_outputs_mean = other_rel_base_pts_outputs if self.args.deep_fuse_timeemb: ## denoising process ### ## GET joints seq time embeddings ### ### embed time stamps ### basejtsrel_time_emb = self.basejtsrel_denoising_embed_timestep(timesteps) basejtsrel_time_emb = basejtsrel_time_emb.repeat(rel_base_pts_outputs_mean.size(0), 1, 1).contiguous() basejtsrel_seq_latents = rel_base_pts_outputs_mean + basejtsrel_time_emb ### time embeddings and relbaseptsoutputs # print(f"basejtsrel_seq_latents: {basejtsrel_seq_latents.size()}, rel_base_pts_outputs_mean: {rel_base_pts_outputs_mean.size()}, basejtsrel_time_emb: {basejtsrel_time_emb.size()}") if self.args.use_ours_transformer_enc: basejtsrel_seq_latents = self.basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) basejtsrel_seq_latents = basejtsrel_seq_latents.permute(1, 0, 2) else: basejtsrel_seq_latents = self.basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents) else: ## denoising process ### ## GET joints seq time embeddings ### ### embed time stamps ### basejtsrel_time_emb = self.basejtsrel_denoising_embed_timestep(timesteps) basejtsrel_seq_latents = torch.cat( [basejtsrel_time_emb, rel_base_pts_outputs_mean], dim=0 ) if self.args.use_ours_transformer_enc: ## mdm ours ## basejtsrel_seq_latents = self.basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) basejtsrel_seq_latents = basejtsrel_seq_latents.permute(1, 0, 2)[1:] else: basejtsrel_seq_latents = self.basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents)[1:] # basejtsrel_seq_latents_pred_feats # avg_jts_seq_latents = basejtsrel_seq_latents[0:1, ...] other_basejtsrel_seq_latents = basejtsrel_seq_latents # [1:, ...] joints_offset_output = self.joint_offset_output_process(other_basejtsrel_seq_latents) joints_offset_output = joints_offset_output.permute(0, 3, 1, 2) # print(f"joints_offset_output in MDM: {joints_offset_output.size()}, joints_offset_sequence: {joints_offset_sequence.size()}, other_basejtsrel_seq_latents: {other_basejtsrel_seq_latents.size()}") diff_basejtsrel_dict = { 'joints_offset_output': joints_offset_output } # rt_dict['joints_offset_output'] = joints_offset_output else: avg_joints_sequence = x['pert_avg_joints_sequence'] avg_joints_sequence_trans = avg_joints_sequence.unsqueeze(-1) avg_joints_feats = self.avg_joints_sequence_input_process(avg_joints_sequence_trans) ## 1 x bsz x dim ### rel_base_pts_feats = torch.cat( # (seq_len + 1) x bsz x dim # [avg_joints_feats, rel_base_pts_feats_pos_embedding], dim=0 ## jrel_base_pts_pos_embedding # ) ## joints embedding for mean statistics and logvar statistics ## # seqTransEncoder, logvar_seqTransEncoder # rel_base_pts_outputs_mean = self.seqTransEncoder(rel_base_pts_feats) if not self.args.without_dec_pos_emb: # without dec pos embedding avg_jts_inputs = rel_base_pts_outputs_mean[0:1, ...] other_rel_base_pts_outputs = rel_base_pts_outputs_mean[1: , ...] other_rel_base_pts_outputs = self.sequence_pos_denoising_encoder(other_rel_base_pts_outputs) rel_base_pts_outputs_mean = torch.cat( [avg_jts_inputs, other_rel_base_pts_outputs], dim=0 ) if self.args.deep_fuse_timeemb: ## denoising process ### ## GET joints seq time embeddings ### ### embed time stamps ### basejtsrel_time_emb = self.basejtsrel_denoising_embed_timestep(timesteps) basejtsrel_time_emb = basejtsrel_time_emb.repeat(rel_base_pts_outputs_mean.size(0), 1, 1).contiguous() basejtsrel_seq_latents = rel_base_pts_outputs_mean + basejtsrel_time_emb ### time embeddings and relbaseptsoutputs if self.args.use_ours_transformer_enc: basejtsrel_seq_latents = self.basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) basejtsrel_seq_latents = basejtsrel_seq_latents.permute(1, 0, 2) else: basejtsrel_seq_latents = self.basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents) else: ## denoising process ### ## GET joints seq time embeddings ### ### embed time stamps ### basejtsrel_time_emb = self.basejtsrel_denoising_embed_timestep(timesteps) basejtsrel_seq_latents = torch.cat( [basejtsrel_time_emb, rel_base_pts_outputs_mean], dim=0 ) if self.args.use_ours_transformer_enc: ## mdm ours ## basejtsrel_seq_latents = self.basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) basejtsrel_seq_latents = basejtsrel_seq_latents.permute(1, 0, 2)[1:] else: basejtsrel_seq_latents = self.basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents)[1:] # basejtsrel_seq_latents_pred_feats avg_jts_seq_latents = basejtsrel_seq_latents[0:1, ...] other_basejtsrel_seq_latents = basejtsrel_seq_latents[1:, ...] avg_jts_outputs = self.avg_joint_sequence_output_process(avg_jts_seq_latents) # bsz x njoints x nfeats x 1 avg_jts_outputs = avg_jts_outputs.squeeze(-1) # bsz x nnjoints x 1 --> avg joints here # # basejtsrel_seq_latents = self.basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents)[1:] # basejtsrel_output = self.output_process(other_basejtsrel_seq_latents, x) joints_offset_output = self.joint_offset_output_process(other_basejtsrel_seq_latents) joints_offset_output = joints_offset_output.permute(0, 3, 1, 2) rt_dict['joints_offset_output'] = joints_offset_output rt_dict['avg_jts_outputs'] = avg_jts_outputs else: diff_basejtsrel_dict = {} rt_dict = {} rt_dict.update(diff_basejtsrel_dict) rt_dict.update(diff_basejtse_dict) ### rt_dict and diff_basejtse rt_dict.update(diff_basejtsrel_to_joints_dict) rt_dict.update(diff_realbasejtsrel_out_dict) ### diff return rt_dict def _apply(self, fn): super()._apply(fn) # self.rot2xyz.smpl_model._apply(fn) def train(self, *args, **kwargs): super().train(*args, **kwargs) # self.rot2xyz.smpl_model.train(*args, **kwargs) ### MDM 10 ### class MDMV12(nn.Module): def __init__(self, modeltype, njoints, nfeats, num_actions, translation, pose_rep, glob, glob_rot, latent_dim=256, ff_size=1024, num_layers=8, num_heads=4, dropout=0.1, ablation=None, activation="gelu", legacy=False, data_rep='rot6d', dataset='amass', clip_dim=512, arch='trans_enc', emb_trans_dec=False, clip_version=None, **kargs): super().__init__() self.legacy = legacy self.modeltype = modeltype self.njoints = njoints self.nfeats = nfeats self.num_actions = num_actions self.data_rep = data_rep self.dataset = dataset self.pose_rep = pose_rep self.glob = glob self.glob_rot = glob_rot self.translation = translation self.latent_dim = latent_dim self.ff_size = ff_size self.num_layers = num_layers self.num_heads = num_heads self.dropout = dropout self.ablation = ablation self.activation = activation self.clip_dim = clip_dim self.action_emb = kargs.get('action_emb', None) self.input_feats = self.njoints * self.nfeats self.normalize_output = kargs.get('normalize_encoder_output', False) self.cond_mode = kargs.get('cond_mode', 'no_cond') self.cond_mask_prob = kargs.get('cond_mask_prob', 0.) # ### GET args ### self.args = kargs.get('args', None) ### GET the diff. suit ### self.diff_jts = self.args.diff_jts self.diff_basejtsrel = self.args.diff_basejtsrel self.diff_basejtse = self.args.diff_basejtse self.diff_realbasejtsrel = self.args.diff_realbasejtsrel self.diff_realbasejtsrel_to_joints = self.args.diff_realbasejtsrel_to_joints ### GET the diff. suit ### self.arch = arch ## ==== gru_emb_dim ==== ## # gru emb dim # self.gru_emb_dim = self.latent_dim if self.arch == 'gru' else 0 # # joint_sequence_input_process; joint_sequence_pos_encoder; joint_sequence_seqTransEncoder; joint_sequence_seqTransDecoder; joint_sequence_embed_timestep; joint_sequence_output_process # ###### ======= Construct joint sequence encoder, communicator, and decoder ======== ####### self.joints_feats_in_dim = 21 * 3 self.data_rep = "xyz" if self.diff_jts: ## Input process for joints ## self.joint_sequence_input_process = InputProcess(self.data_rep, self.joints_feats_in_dim, self.latent_dim) # # InputProcessObjBase(self, data_rep, input_feats, latent_dim) # self.joint_sequence_input_process = InputProcessObjBase(self.data_rep, 3, self.latent_dim) # self.input_process = InputProcess(self.data_rep, self.input_feats+self.gru_emb_dim, self.latent_dim) self.joint_sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout) self.emb_trans_dec = emb_trans_dec # if self.arch == 'trans_enc': print("TRANS_ENC init") ## transformer encoder layer ## UNet seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, nhead=self.num_heads, dim_feedforward=self.ff_size, dropout=self.dropout, activation=self.activation) ## Joint sequence transformer encoder layer ## # sequence encoder # self.joint_sequence_seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer, num_layers=self.num_layers) ### logvar for the encoding laeyer and # # logvar_seqTransEncoder_e, logvar_seqTransEncoder, joint_sequence_logvar_seqTransEncoder # # seqTransEncoderLayer_logvar = nn.TransformerEncoderLayer(d_model=self.latent_dim, nhead=self.num_heads, dim_feedforward=self.ff_size, dropout=self.dropout, activation=self.activation) ## Joint sequence transformer encoder layer ## # sequence encoder # self.joint_sequence_logvar_seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer_logvar, num_layers=self.num_layers) # elif self.arch == 'trans_dec': # print("TRANS_DEC init") # seqTransDecoderLayer = nn.TransformerDecoderLayer(d_model=self.latent_dim, # nhead=self.num_heads, # num_heads # dim_feedforward=self.ff_size, # dropout=self.dropout, # activation=activation) # self.joint_sequence_seqTransDecoder = nn.TransformerDecoder(seqTransDecoderLayer, # num_layers=self.num_layers) # elif self.arch == 'gru': # print("GRU init") # self.joint_sequence_gru = nn.GRU(self.latent_dim, self.latent_dim, num_layers=self.num_layers, batch_first=True) # else: # raise ValueError('Please choose correct architecture [trans_enc, trans_dec, gru]') ### joint sequence embed timestep ## ## timestep self.joint_sequence_embed_timestep = TimestepEmbedder(self.latent_dim, self.joint_sequence_pos_encoder) # self.joint_sequence_output_process = OutputProcess(self.data_rep, self.latent_dim) # (self, data_rep, input_feats, latent_dim, njoints, nfeats): #### ====== joint sequence denoising block ====== #### ## seqTransEncoder ## self.joint_sequence_denoising_embed_timestep = TimestepEmbedder(self.latent_dim, self.joint_sequence_pos_encoder) self.joint_sequence_denoising_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout) if self.args.use_ours_transformer_enc: self.joint_sequence_denoising_seqTransEncoder = model_utils.TransformerEncoder( hidden_size=self.latent_dim, fc_size=self.ff_size, num_heads=self.num_heads, layer_norm=True, num_layers=self.num_layers, dropout_rate=0.2, re_zero=True, memory_efficient=False, ) else: seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, nhead=self.num_heads, dim_feedforward=self.ff_size, dropout=self.dropout, activation=self.activation) self.joint_sequence_denoising_seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer, num_layers=self.num_layers) # seq_len x bsz x dim if self.args.const_noise: # 1) max pool latents over the sequence # 2) transform the pooled latnets via the linear layer self.glb_denoising_latents_trans_layer = nn.Sequential( nn.Linear(self.latent_dim, self.latent_dim * 2), nn.ReLU(), nn.Linear(self.latent_dim * 2, self.latent_dim) ) # refinement for predicted joints # --> not in the paradigm of generation # # seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, # nhead=self.num_heads, # dim_feedforward=self.ff_size, # dropout=self.dropout, # activation=self.activation) # self.joint_sequence_denoising_seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer, # num_layers=self.num_layers) #### ====== joint sequence denoiisng block ====== #### ### Output process ### output proces for joint sequence ### # output proces --> datarep, joints feats in dim, latent dim ## ###### joints_feats_in_dim ###### self.joint_sequence_output_process = OutputProcess(self.data_rep, self.joints_feats_in_dim, self.latent_dim, 21, 3) # OutputProcessCond # self.joint_sequence_output_process = OutputProcessCond(self.data_rep, self.joints_feats_in_dim, self.latent_dim, 21, 3) ###### ======= Construct joint sequence encoder, communicator, and decoder ======== ####### # real_basejtsrel_to_joints_embed_timestep, real_basejtsrel_to_joints_sequence_pos_denoising_encoder, real_basejtsrel_to_joints_denoising_seqTransEncoder, real_basejtsrel_to_joints_output_process if self.diff_realbasejtsrel_to_joints: # feature for each joint point? --> for the denoising purpose # # real_basejtsrel_input_process, real_basejtsrel_sequence_pos_encoder, real_basejtsrel_seqTransEncoder, real_basejtsrel_embed_timestep, real_basejtsrel_sequence_pos_denoising_encoder, real_basejtsrel_denoising_seqTransEncoder layernorm = True self.rel_input_feats = 21 * (3 + 3 + 3) # base pts, normals, the relative positions if self.args.use_abs_jts_for_encoding_obj_base: self.rel_input_feats = 21 * (3) # layernorm = False self.real_basejtsrel_to_joints_input_process = InputProcessObjBaseV2(self.data_rep, self.rel_input_feats, self.latent_dim, layernorm=layernorm) # self.real_basejtsrel_to_joints_input_process = InputProcessObjBase(self.data_rep, self.rel_input_feats, self.latent_dim, layernorm=layernorm) # elif self.args.use else: if self.args.use_objbase_v2: self.rel_input_feats = 3 + 3 + (21 * 3) self.real_basejtsrel_to_joints_input_process = InputProcessObjBaseV2(self.data_rep, self.rel_input_feats, self.latent_dim, layernorm=layernorm) elif self.args.use_objbase_v3: self.rel_input_feats = 3 + 3 + (21 * 3) self.real_basejtsrel_to_joints_input_process = InputProcessObjBaseV3(self.data_rep, self.rel_input_feats, self.latent_dim, layernorm=layernorm) else: self.real_basejtsrel_to_joints_input_process = InputProcessObjBase(self.data_rep, self.rel_input_feats, self.latent_dim, layernorm=layernorm) if self.args.use_abs_jts_for_encoding: # use_abs_jts_for_encoding, real_basejtsrel_to_joints_input_process self.real_basejtsrel_to_joints_input_process = InputProcess(self.data_rep, self.joints_feats_in_dim, self.latent_dim) self.real_basejtsrel_to_joints_sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout) ### Encoding layer ### # InputProcessObjBaseV2 real_basejtsrel_to_joints_seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, # latent dim # nn_heads # ff_size # dropout # nhead=self.num_heads, dim_feedforward=self.ff_size, dropout=self.dropout, activation=self.activation) self.real_basejtsrel_to_joints_seqTransEncoder = nn.TransformerEncoder(real_basejtsrel_to_joints_seqTransEncoderLayer, # basejtsrel_seqTrans num_layers=self.num_layers) ### timesteps embedding layer ### self.real_basejtsrel_to_joints_embed_timestep = TimestepEmbedder(self.latent_dim, self.real_basejtsrel_to_joints_sequence_pos_encoder) self.real_basejtsrel_to_joints_sequence_pos_denoising_encoder = PositionalEncoding(self.latent_dim, self.dropout) if self.args.use_ours_transformer_enc: # our transformer encoder # self.real_basejtsrel_to_joints_denoising_seqTransEncoder = model_utils.TransformerEncoder( hidden_size=self.latent_dim, fc_size=self.ff_size, num_heads=self.num_heads, layer_norm=True, num_layers=self.num_layers, dropout_rate=0.2, re_zero=True, memory_efficient=False, ) else: real_basejtsrel_to_joints_denoising_seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, nhead=self.num_heads, dim_feedforward=self.ff_size, dropout=self.dropout, activation=self.activation) self.real_basejtsrel_to_joints_denoising_seqTransEncoder = nn.TransformerEncoder(real_basejtsrel_to_joints_denoising_seqTransEncoderLayer, num_layers=self.num_layers) # self.real_basejtsrel_output_process = OutputProcessObjBaseRawV2(self.data_rep, self.latent_dim, not_cond_base=self.args.not_cond_base) self.real_basejtsrel_to_joints_output_process = OutputProcess(self.data_rep, self.joints_feats_in_dim, self.latent_dim, 21, 3) # OutputProcessCond if self.diff_realbasejtsrel: # real_basejtsrel_input_process, real_basejtsrel_sequence_pos_encoder, real_basejtsrel_seqTransEncoder, real_basejtsrel_embed_timestep, real_basejtsrel_sequence_pos_denoising_encoder, real_basejtsrel_denoising_seqTransEncoder self.rel_input_feats = 21 * (3 + 3 + 3) # base pts, normals, the relative positions # self.real_basejtsrel_input_process = InputProcessObjBase(self.data_rep, self.rel_input_feats, self.latent_dim) layernorm = True if self.args.use_objbase_v2: self.rel_input_feats = 3 + 3 + (21 * 3) self.real_basejtsrel_input_process = InputProcessObjBaseV2(self.data_rep, self.rel_input_feats, self.latent_dim, layernorm=layernorm, glb_feats_trans=True) elif self.args.use_objbase_v4: # use_objbase_out_v4 self.rel_input_feats = (self.args.nn_base_pts * (3 + 3 + 3)) # current joint positions # how to keep the dimension self.real_basejtsrel_input_process = InputProcessObjBaseV4(self.data_rep, self.rel_input_feats, self.latent_dim, layernorm=layernorm) elif self.args.use_objbase_v5: # use_objbase_v5, if self.args.v5_in_not_base: self.rel_input_feats = (21 * 3) elif self.args.v5_in_not_base_pos: self.rel_input_feats = 3 + (21 * 3) else: self.rel_input_feats = 3 + 3 + (21 * 3) self.real_basejtsrel_input_process = InputProcessObjBaseV5(self.data_rep, self.rel_input_feats, self.latent_dim, layernorm=layernorm, without_glb=self.args.v5_in_without_glb) elif self.args.use_objbase_v6: # real_basejtsrel_input_process self.rel_input_feats = 3 + 3 + (21 * 3) + 3 self.real_basejtsrel_input_process = InputProcessObjBaseV6(self.data_rep, self.rel_input_feats, self.latent_dim, layernorm=layernorm) elif self.args.use_objbase_v7: # InputProcessObjBaseV7 self.rel_input_feats = 3 + 3 + (21 * 3) self.real_basejtsrel_input_process = InputProcessObjBaseV7(self.data_rep, self.rel_input_feats, self.latent_dim, layernorm=layernorm) else: self.real_basejtsrel_input_process = InputProcessObjBase(self.data_rep, self.rel_input_feats, self.latent_dim, layernorm=layernorm) self.real_basejtsrel_sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout) ### Encoding layer ### real_basejtsrel_seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, # latent dim # nn_heads # ff_size # dropout # # dropout # # dropout nhead=self.num_heads, dim_feedforward=self.ff_size, dropout=self.dropout, activation=self.activation) self.real_basejtsrel_seqTransEncoder = nn.TransformerEncoder(real_basejtsrel_seqTransEncoderLayer, # basejtsrel_seqTrans num_layers=self.num_layers) ### timesteps embedding layer ### self.real_basejtsrel_embed_timestep = TimestepEmbedder(self.latent_dim, self.real_basejtsrel_sequence_pos_encoder) self.real_basejtsrel_sequence_pos_denoising_encoder = PositionalEncoding(self.latent_dim, self.dropout) if self.args.use_ours_transformer_enc: # our transformer encoder # self.real_basejtsrel_denoising_seqTransEncoder = model_utils.TransformerEncoder( hidden_size=self.latent_dim, fc_size=self.ff_size, num_heads=self.num_heads, layer_norm=True, num_layers=self.num_layers, dropout_rate=0.2, re_zero=True, memory_efficient=False, ) else: real_basejtsrel_denoising_seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, nhead=self.num_heads, dim_feedforward=self.ff_size, dropout=self.dropout, activation=self.activation) self.real_basejtsrel_denoising_seqTransEncoder = nn.TransformerEncoder(real_basejtsrel_denoising_seqTransEncoderLayer, num_layers=self.num_layers) print(f"not_cond_base: {self.args.not_cond_base}, latent_dim: {self.latent_dim}") if self.args.use_jts_pert_realbasejtsrel: print(f"use_jts_pert_realbasejtsrel!!!!!!") self.real_basejtsrel_output_process = OutputProcess(self.data_rep, self.joints_feats_in_dim, self.latent_dim, 21, 3) else: if self.args.use_objbase_out_v3: self.real_basejtsrel_output_process = OutputProcessObjBaseRawV3(self.data_rep, self.latent_dim, not_cond_base=self.args.not_cond_base) elif self.args.use_objbase_out_v4: self.real_basejtsrel_output_process = OutputProcessObjBaseRawV4(self.data_rep, self.latent_dim, not_cond_base=self.args.not_cond_base) elif self.args.use_objbase_out_v5: # use_objbase_v5, use_objbase_out_v5 self.real_basejtsrel_output_process = OutputProcessObjBaseRawV5(self.data_rep, self.latent_dim, not_cond_base=self.args.not_cond_base, out_objbase_v5_bundle_out=self.args.out_objbase_v5_bundle_out, v5_out_not_cond_base=self.args.v5_out_not_cond_base) else: self.real_basejtsrel_output_process = OutputProcessObjBaseRawV2(self.data_rep, self.latent_dim, not_cond_base=self.args.not_cond_base) # OutputProcessCond if self.diff_basejtsrel: ## basejtsrel ## # treate them as textures of signals to model # # base pts -> dec on base pts features --> # latent space denoising and feature decoding --> a little bit concern about the feature decoding process # # TODO: add base_pts and base_normals to the base points -rel-to- rhand joints encoding process # self.rel_input_feats = 21 * (3 + 3 + 3) # relative positions from base pts to rhand joints ## # cond_real_basejtsrel_input_process, cond_real_basejtsrel_pos_encoder, cond_real_basejtsrel_seqTransEncoderLayer, cond_trans_linear_layer # # Conditional strategy 1: use the relative position embeddings for guidance (cannot use the origianl weights for finetuning) # layernorm = True # [finetune_with_cond_rel, finetune_with_cond_jtsobj] if self.args.finetune_with_cond_rel: if self.args.use_objbase_v5: # use_objbase_v5 # if self.args.v5_in_not_base: self.rel_input_feats = (21 * 3) elif self.args.v5_in_not_base_pos: self.rel_input_feats = 3 + (21 * 3) else: # as an additional conditions for the input and denoising # self.rel_input_feats = 3 + 3 + (21 * 3) self.cond_real_basejtsrel_input_process = InputProcessObjBaseV5(self.data_rep, self.rel_input_feats, self.latent_dim, layernorm=layernorm, without_glb=self.args.v5_in_without_glb, only_with_glb=True) else: raise ValueError(f"Must use objbase_v5 currently, others have not been implemented yet.") self.cond_real_basejtsrel_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout) cond_real_basejtsrel_seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, nhead=self.num_heads, dim_feedforward=self.ff_size, dropout=self.dropout, activation=self.activation) self.cond_real_basejtsrel_seqTransEncoderLayer = nn.TransformerEncoder(cond_real_basejtsrel_seqTransEncoderLayer, num_layers=self.num_layers) elif self.args.finetune_with_cond_jtsobj: # cond_obj_trans_layer # InputProcessObjV6(self, data_rep, input_feats, latent_dim, layernorm=True) # cond_obj_input_layer, cond_obj_trans_layer, cond_joints_offset_input_process, cond_sequence_pos_encoder, cond_jtsobj_seqTransEncoder # # TODO: cond_obj_trans_layer : cond_joints_offset_input_process <- joints_offset_input_process; cond_sequence_pos_encoder <- sequence_pos_encoder; cond_seqTransEncoder <- seqTransEncoder self.cond_obj_input_feats = 3 # if self.args.finetune_cond_obj_feats_dim == 3: # self.cond_obj_input_feats = 3 self.cond_obj_input_feats = self.args.finetune_cond_obj_feats_dim # finetune cond with obj feats dim # self.cond_obj_input_layer = InputProcessObjV6(self.data_rep, self.cond_obj_input_feats, self.latent_dim, layernorm=layernorm) # TODO: remember to set this layer to zero! # self.cond_obj_trans_layer = nn.Linear(self.latent_dim, self.latent_dim) # hand_embedding + zero_trans_layer(obj_cond_embedding) self.cond_joints_offset_input_process = InputProcess(self.data_rep, self.joints_feats_in_dim, self.latent_dim) self.cond_sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout) cond_jtsobj_seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, nhead=self.num_heads, dim_feedforward=self.ff_size, dropout=self.dropout, activation=self.activation) self.cond_seqTransEncoder = nn.TransformerEncoder(cond_jtsobj_seqTransEncoderLayer, num_layers=self.num_layers) else: raise ValueError(f"Either ``finetune_with_cond_rel'' or ``finetune_with_cond_jtsobj'' should be activated!") # TODO: remember to initialize it to zero! # self.cond_trans_linear_layer = nn.Linear(self.latent_dim, self.latent_dim) # self.avg_joints_sequence_input_process, self.avg_joint_sequence_output_process # self.avg_joints_sequence_input_process = InputProcess(self.data_rep, self.joints_feats_in_dim, self.latent_dim) # TODO: should set those weights to frozen --> joints offset input process... self.joints_offset_input_process = InputProcess(self.data_rep, self.joints_feats_in_dim, self.latent_dim) if self.args.not_cond_base: self.rel_input_feats = 21 * ( 3) # self.input_process = InputProcessObjBase(self.data_rep, self.rel_input_feats+self.gru_emb_dim, self.latent_dim) self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout) self.emb_trans_dec = emb_trans_dec ### Encoding layer ### seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, nhead=self.num_heads, dim_feedforward=self.ff_size, dropout=self.dropout, activation=self.activation) self.seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer, num_layers=self.num_layers) ### Encoding layer ### # logvar_seqTransEncoder_e, logvar_seqTransEncoder # logvar_seqTranEncoder seqTransEncoderLayer_logvar = nn.TransformerEncoderLayer(d_model=self.latent_dim, nhead=self.num_heads, dim_feedforward=self.ff_size, dropout=self.dropout, activation=self.activation) self.logvar_seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer_logvar, num_layers=self.num_layers) ### timesteps embedding layer ### # TimestepEmbedder -> embedding times # self.embed_timestep = TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder) # basejtsrel_denoising_embed_timestep, basejtsrel_denoising_seqTransEncoder, output_process # # baseptsrel # self.basejtsrel_denoising_embed_timestep = TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder) self.sequence_pos_denoising_encoder = PositionalEncoding(self.latent_dim, self.dropout) if self.args.use_ours_transformer_enc: # our transformer encoder # self.basejtsrel_denoising_seqTransEncoder = model_utils.TransformerEncoder( hidden_size=self.latent_dim, fc_size=self.ff_size, num_heads=self.num_heads, layer_norm=True, num_layers=self.num_layers, dropout_rate=0.2, re_zero=True, memory_efficient=False, ) else: seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, nhead=self.num_heads, dim_feedforward=self.ff_size, dropout=self.dropout, activation=self.activation) self.basejtsrel_denoising_seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer, num_layers=self.num_layers) # # seq_len x bsz x dim # if self.args.const_noise: # add to attention network # add j # # 1) max pool latents over the sequence # # 2) transform the pooled latnets via the linear layer # self.basejtsrel_glb_denoising_latents_trans_layer = nn.Sequential( # nn.Linear(self.latent_dim, self.latent_dim * 2), nn.ReLU(), # nn.Linear(self.latent_dim * 2, self.latent_dim) # ) ###### joints_feats_in_dim ###### # a linear transformation net with weights and bias set to zero # self.avg_joint_sequence_output_process = OutputProcess(self.data_rep, self.joints_feats_in_dim, self.latent_dim, 21, 3) # output avgjts sequence # OutputProcessCond self.joint_offset_output_process = OutputProcess(self.data_rep, self.joints_feats_in_dim, self.latent_dim, 21, 3) if self.args.use_dec_rel_v2: self.output_process = OutputProcessObjBaseRawV2(self.data_rep, self.latent_dim, not_cond_base=self.args.not_cond_base, finetune_with_cond=self.args.finetune_with_cond) else: # OutputProcessObjBaseRaw ## output process for basejtsrel # self.output_process = OutputProcessObjBaseRaw(self.data_rep, self.latent_dim, not_cond_base=self.args.not_cond_base) ##### ==== input process, communications, output process for rel, dists ==== ##### if self.diff_basejtse: ### input process obj base ### # construct input_process_e # # self.input_feats_e = 21 * (3 + 3 + 3 + 1 + 1) # self.input_feats_e = 21 * (3 + 3 + 1 + 1) self.input_feats_e = 21 * (3 + 3 + 1 + 1 + 3 + 3 + 1) self.input_process_e = InputProcessObjBase(self.data_rep, self.input_feats_e+self.gru_emb_dim, self.latent_dim) self.sequence_pos_encoder_e = PositionalEncoding(self.latent_dim, self.dropout) self.emb_trans_dec = emb_trans_dec # # single layer transformers # ## predict relative position for each base point? # existing model # if self.arch == 'trans_enc': print("TRANS_ENC init") seqTransEncoderLayer_e = nn.TransformerEncoderLayer(d_model=self.latent_dim, nhead=self.num_heads, dim_feedforward=self.ff_size, dropout=self.dropout, activation=self.activation) self.seqTransEncoder_e = nn.TransformerEncoder(seqTransEncoderLayer_e, num_layers=self.num_layers) print("TRANS_ENC init") # logvar_seqTransEncoder_e, seqTransEncoderLayer_e_logvar = nn.TransformerEncoderLayer(d_model=self.latent_dim, nhead=self.num_heads, dim_feedforward=self.ff_size, dropout=self.dropout, activation=self.activation) self.logvar_seqTransEncoder_e = nn.TransformerEncoder(seqTransEncoderLayer_e_logvar, num_layers=self.num_layers) # # elif self.arch == 'trans_dec': # print("TRANS_DEC init") # seqTransDecoderLayer_e = nn.TransformerDecoderLayer(d_model=self.latent_dim, # nhead=self.num_heads, # dim_feedforward=self.ff_size, # dropout=self.dropout, # activation=activation) # self.seqTransDecoder_e = nn.TransformerDecoder(seqTransDecoderLayer_e, # num_layers=self.num_layers) # elif self.arch == 'gru': ## arch ## # print("GRU init") # self.gru_e = nn.GRU(self.latent_dim, self.latent_dim, num_layers=self.num_layers, batch_first=True) # else: # raise ValueError('Please choose correct architecture [trans_enc, trans_dec, gru]') # tiemstep # # timestep embedding e # Embed timestep e # self.embed_timestep_e = TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder_e) self.sequence_pos_denoising_encoder_e = PositionalEncoding(self.latent_dim, self.dropout) # basejtsrel_denoising_embed_timestep, basejtsrel_denoising_seqTransEncoder, output_process # self.basejtse_denoising_embed_timestep = TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder_e) if self.args.use_ours_transformer_enc: # our transformer encoder # self.basejtse_denoising_seqTransEncoder = model_utils.TransformerEncoder( hidden_size=self.latent_dim, fc_size=self.ff_size, num_heads=self.num_heads, layer_norm=True, num_layers=self.num_layers, dropout_rate=0.2, re_zero=True, memory_efficient=False, ) ### basejtse_denoising_seqTransEncoder ### else: basejtse_seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, nhead=self.num_heads, dim_feedforward=self.ff_size, dropout=self.dropout, activation=self.activation) self.basejtse_denoising_seqTransEncoder = nn.TransformerEncoder(basejtse_seqTransEncoderLayer, num_layers=self.num_layers) # basejtse_seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, # nhead=self.num_heads, # dim_feedforward=self.ff_size, # dropout=self.dropout, # activation=self.activation) # self.basejtse_denoising_seqTransEncoder = nn.TransformerEncoder(basejtse_seqTransEncoderLayer, # num_layers=self.num_layers) # seq_len x bsz x dim # if self.args.const_noise: # # 1) max pool latents over the sequence # # 2) transform the pooled latnets via the linear layer # self.basejtse_denoising_seqTransEncoder = nn.Sequential( # nn.Linear(self.latent_dim, self.latent_dim * 2), nn.ReLU(), # nn.Linear(self.latent_dim * 2, self.latent_dim) # ) # self.output_process_e = OutputProcessObjBaseV3(self.data_rep, self.latent_dim) self.output_process_e = OutputProcessObjBaseERaw(self.data_rep, self.latent_dim) # self.rot2xyz = Rotation2xyz(device='cpu', dataset=self.dataset) def set_enc_to_eval(self): # jts: joint_sequence_input_process, joint_sequence_pos_encoder, joint_sequence_seqTransEncoder, joint_sequence_logvar_seqTransEncoder # basejtsrel: input_process, sequence_pos_encoder, seqTransEncoder, logvar_seqTransEncoder, # basejtse: input_process_e, sequence_pos_encoder_e, seqTransEncoder_e, logvar_seqTransEncoder_e, if self.diff_jts: self.joint_sequence_input_process.eval() self.joint_sequence_pos_encoder.eval() self.joint_sequence_seqTransEncoder.eval() self.joint_sequence_logvar_seqTransEncoder.eval() if self.diff_basejtse: self.input_process_e.eval() self.sequence_pos_encoder_e.eval() self.seqTransEncoder_e.eval() self.logvar_seqTransEncoder_e.eval() if self.diff_basejtsrel: self.input_process.eval() self.sequence_pos_encoder.eval() self.seqTransEncoder.eval() # seqTransEncoder, logvar_seqTransEncoder self.logvar_seqTransEncoder.eval() def set_bn_to_eval(self): if self.args.use_objbase_v6: # real_basejtsrel_input_process try: self.real_basejtsrel_input_process.pnpp_conv_net.set_bn_no_training() except: pass def parameters_wo_clip(self): return [p for name, p in self.named_parameters() if not name.startswith('clip_model.')] def load_and_freeze_clip(self, clip_version): clip_model, clip_preprocess = clip.load(clip_version, device='cpu', jit=False) # Must set jit=False for training clip.model.convert_weights( # encode clip_model) # Actually this line is unnecessary since clip by default already on float16 # Freeze CLIP weights clip_model.eval() for p in clip_model.parameters(): p.requires_grad = False return clip_model def mask_cond(self, cond, force_mask=False): bs, d = cond.shape if force_mask: return torch.zeros_like(cond) elif self.training and self.cond_mask_prob > 0.: mask = torch.bernoulli(torch.ones(bs, device=cond.device) * self.cond_mask_prob).view(bs, 1) # 1-> use null_cond, 0-> use real cond return cond * (1. - mask) else: return cond def encode_text(self, raw_text): # raw_text - list (batch_size length) of strings with input text prompts # device = next(self.parameters()).device max_text_len = 20 if self.dataset in ['humanml', 'kit', 'motion_ours'] else None # Specific hardcoding for humanml dataset if max_text_len is not None: default_context_length = 77 context_length = max_text_len + 2 # start_token + 20 + end_token assert context_length < default_context_length texts = clip.tokenize(raw_text, context_length=context_length, truncate=True).to(device) # [bs, context_length] # if n_tokens > context_length -> will truncate # print('texts', texts.shape) zero_pad = torch.zeros([texts.shape[0], default_context_length-context_length], dtype=texts.dtype, device=texts.device) texts = torch.cat([texts, zero_pad], dim=1) # print('texts after pad', texts.shape, texts) else: ## texts = clip.tokenize(raw_text, truncate=True).to(device) # [bs, context_length] # if n_tokens > 77 -> will truncate return self.clip_model.encode_text(texts).float() def dec_jts_only_fr_latents(self, latents_feats): joint_seq_output = self.joint_sequence_output_process(latents_feats) # [bs, njoints, nfeats, nframes] # bsz x ws x nnj x nnfeats # --> joints_seq_outputs # joint_seq_output = joint_seq_output.permute(0, 3, 1, 2).contiguous() ## joints seq outputs ## diff_jts_dict = { "joint_seq_output": joint_seq_output, "joints_seq_latents": latents_feats, } return diff_jts_dict def dec_basejtsrel_only_fr_latents(self, latent_feats, x): # basejtsrel_seq_latents_pred_feats avg_jts_seq_latents = latent_feats[0:1, ...] other_basejtsrel_seq_latents = latent_feats[1:, ...] avg_jts_outputs = self.avg_joint_sequence_output_process(avg_jts_seq_latents) # bsz x njoints x nfeats x 1 avg_jts_outputs = avg_jts_outputs.squeeze(-1) # bsz x nnjoints x 1 --> avg joints here # # basejtsrel_seq_latents = self.basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents)[1:] basejtsrel_output = self.output_process(other_basejtsrel_seq_latents, x) basejtsrel_dec_out = { 'avg_jts_outputs': avg_jts_outputs, 'basejtsrel_output': basejtsrel_output['dec_rel'], } return basejtsrel_dec_out # def dec_latents_to_joints_with_t(self, joints_seq_latents, timesteps): def dec_latents_to_joints_with_t(self, input_latent_feats, x, timesteps): # # logvar_seqTransEncoder_e, logvar_seqTransEncoder, joint_sequence_logvar_seqTransEncoder # # # joints_seq_latents: seq x bs x d --> perturbed joitns_seq_latents \in [-1, 1] ## # def dec_latents_to_joints_with_t(self, joints_seq_latents, timesteps): ## positional encoding for denoising ## # rt_dict = { # 'joint_seq_output': joint_seq_output, # 'rel_base_pts_outputs': rel_base_pts_outputs, # } rt_dict = {} if self.diff_jts: ####### input latent feats ####### joints_seq_latents = input_latent_feats["joints_seq_latents"] if not self.args.without_dec_pos_emb: joints_seq_latents = self.joint_sequence_denoising_pos_encoder(joints_seq_latents) # ### GET joints seq time embeddings ### ### embed time stamps ### # joints_seq_time_emb = self.joint_sequence_embed_timestep(timesteps) # joints_seq_latents = torch.cat( # [joints_seq_time_emb, joints_seq_latents], dim=0 # ) # joints_seq_latents = self.joint_sequence_denoising_seqTransEncoder(joints_seq_latents)[1:] # seq x bs x d if self.args.deep_fuse_timeemb: ## denoising process ### ## GET joints seq time embeddings ### ### embed time stamps ### joints_seq_time_emb = self.joint_sequence_embed_timestep(timesteps) joints_seq_time_emb = joints_seq_time_emb.repeat(joints_seq_latents.size(0), 1, 1).contiguous() joints_seq_latents = joints_seq_latents + joints_seq_time_emb if self.args.use_ours_transformer_enc: joints_seq_latents = self.joint_sequence_denoising_seqTransEncoder(joints_seq_latents.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) joints_seq_latents = joints_seq_latents.permute(1, 0, 2) else: joints_seq_latents = self.joint_sequence_denoising_seqTransEncoder(joints_seq_latents) else: ## denoising process ### ## GET joints seq time embeddings ### ### embed time stamps ### joints_seq_time_emb = self.joint_sequence_embed_timestep(timesteps) joints_seq_latents = torch.cat( [joints_seq_time_emb, joints_seq_latents], dim=0 ) if self.args.use_ours_transformer_enc: ## mdm ours ## joints_seq_latents = self.joint_sequence_denoising_seqTransEncoder(joints_seq_latents.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) joints_seq_latents = joints_seq_latents.permute(1, 0, 2)[1:] else: joints_seq_latents = self.joint_sequence_denoising_seqTransEncoder(joints_seq_latents)[1:] # joints_seq_latents: seq_len x bsz x latent_dim # if self.args.const_noise: seq_len = joints_seq_latents.size(0) # if self.args.const_noise: joints_seq_latents, _ = torch.max(joints_seq_latents, dim=0, keepdim=True) joints_seq_latents = self.glb_denoising_latents_trans_layer(joints_seq_latents) # seq_len x bsz x latent_dim joints_seq_latents = joints_seq_latents.repeat(seq_len, 1, 1).contiguous() ### sequence latents ### if self.args.train_enc: # trian enc for seq latents ### joints_seq_latents = input_latent_feats["joints_seq_latents_enc"] # bsz x ws x nnj x 3 # joint_seq_output = self.joint_sequence_output_process(joints_seq_latents) # [bs, njoints, nfeats, nframes] # bsz x ws x nnj x nnfeats # --> joints_seq_outputs # joint_seq_output = joint_seq_output.permute(0, 3, 1, 2).contiguous() diff_jts_dict = { "joint_seq_output": joint_seq_output, "joints_seq_latents": joints_seq_latents, } else: diff_jts_dict = {} if self.diff_basejtsrel: rel_base_pts_outputs = input_latent_feats["rel_base_pts_outputs"] if rel_base_pts_outputs.size(0) == 1 and self.args.single_frame_noise: rel_base_pts_outputs = rel_base_pts_outputs.repeat(self.args.window_size + 1, 1, 1) if not self.args.without_dec_pos_emb: # without avg_jts_inputs = rel_base_pts_outputs[0:1, ...] other_rel_base_pts_outputs = rel_base_pts_outputs[1: , ...] other_rel_base_pts_outputs = self.sequence_pos_denoising_encoder(other_rel_base_pts_outputs) rel_base_pts_outputs = torch.cat( [avg_jts_inputs, other_rel_base_pts_outputs], dim=0 ) if self.args.deep_fuse_timeemb: ## denoising process ### ## GET joints seq time embeddings ### ### embed time stamps ### basejtsrel_time_emb = self.basejtsrel_denoising_embed_timestep(timesteps) basejtsrel_time_emb = basejtsrel_time_emb.repeat(rel_base_pts_outputs.size(0), 1, 1).contiguous() basejtsrel_seq_latents = rel_base_pts_outputs + basejtsrel_time_emb if self.args.use_ours_transformer_enc: basejtsrel_seq_latents = self.basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) basejtsrel_seq_latents = basejtsrel_seq_latents.permute(1, 0, 2) else: basejtsrel_seq_latents = self.basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents) else: ## denoising process ### ## GET joints seq time embeddings ### ### embed time stamps ### basejtsrel_time_emb = self.basejtsrel_denoising_embed_timestep(timesteps) basejtsrel_seq_latents = torch.cat( [basejtsrel_time_emb, rel_base_pts_outputs], dim=0 ) if self.args.use_ours_transformer_enc: ## mdm ours ## basejtsrel_seq_latents = self.basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) basejtsrel_seq_latents = basejtsrel_seq_latents.permute(1, 0, 2)[1:] else: basejtsrel_seq_latents = self.basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents)[1:] ### sequence latents ### if self.args.train_enc: # trian enc for seq latents ### basejtsrel_seq_latents = input_latent_feats["rel_base_pts_outputs_enc"] if basejtsrel_seq_latents.size(0) == 1 and self.args.single_frame_noise: basejtsrel_seq_latents = basejtsrel_seq_latents.repeat(self.args.window_size + 1, 1, 1) basejtsrel_seq_latents_pred_feats = basejtsrel_seq_latents elif self.args.pred_diff_noise: basejtsrel_seq_latents_pred_feats = input_latent_feats["rel_base_pts_outputs"] - basejtsrel_seq_latents else: basejtsrel_seq_latents_pred_feats = basejtsrel_seq_latents # rel_base_pts_outputs = self.sequence_pos_denoising_encoder(rel_base_pts_outputs) ### GET joints seq output ### # basejtsrel_denoising_embed_timestep, basejtsrel_denoising_seqTransEncoder, output_process # # basejtsrel_time_emb = self.basejtsrel_denoising_embed_timestep(timesteps) # basejtsrel_seq_latents = torch.cat( # [basejtsrel_time_emb, rel_base_pts_outputs], dim=0 # ) # basejtsrel_seq_latents_pred_feats avg_jts_seq_latents = basejtsrel_seq_latents_pred_feats[0:1, ...] other_basejtsrel_seq_latents = basejtsrel_seq_latents_pred_feats[1:, ...] avg_jts_outputs = self.avg_joint_sequence_output_process(avg_jts_seq_latents) # bsz x njoints x nfeats x 1 avg_jts_outputs = avg_jts_outputs.squeeze(-1) # bsz x nnjoints x 1 --> avg joints here # # basejtsrel_seq_latents = self.basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents)[1:] basejtsrel_output = self.output_process(other_basejtsrel_seq_latents, x) #### # avg_jts_seq_latents = basejtsrel_seq_latents[0:1, ...] # other_basejtsrel_seq_latents = basejtsrel_seq_latents[1:, ...] # avg_jts_outputs = self.avg_joint_sequence_output_process(avg_jts_seq_latents) # bsz x njoints x nfeats x 1 # avg_jts_outputs = avg_jts_outputs.squeeze(-1) # bsz x nnjoints x 1 --> avg joints here # # # basejtsrel_seq_latents = self.basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents)[1:] # basejtsrel_output = self.output_process(other_basejtsrel_seq_latents, x) diff_basejtsrel_dict = { "basejtsrel_output": basejtsrel_output['dec_rel'], "basejtsrel_seq_latents": basejtsrel_seq_latents, "avg_jts_outputs": avg_jts_outputs, } else: diff_basejtsrel_dict = {} if self.diff_basejtse: # e_disp_rel_to_base_along_normals = input_latent_feats['e_disp_rel_to_base_along_normals'] # e_disp_rel_to_baes_vt_normals = input_latent_feats['e_disp_rel_to_baes_vt_normals'] base_jts_e_feats = input_latent_feats['base_jts_e_feats'] # seq x bs x d --> e feats if not self.args.without_dec_pos_emb: # rel_base_pts_outputs = self.sequence_pos_denoising_encoder(rel_base_pts_outputs) base_jts_e_feats = self.sequence_pos_denoising_encoder_e(base_jts_e_feats) if self.args.deep_fuse_timeemb: ## denoising process ### ## GET joints seq time embeddings ### ### embed time stamps ### base_jts_e_time_emb = self.embed_timestep_e(timesteps) base_jts_e_time_emb = base_jts_e_time_emb.repeat(base_jts_e_feats.size(0), 1, 1).contiguous() base_jts_e_feats = base_jts_e_feats + base_jts_e_time_emb if self.args.use_ours_transformer_enc: ## transformer encoder ## base_jts_e_feats = self.basejtse_denoising_seqTransEncoder(base_jts_e_feats.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) base_jts_e_feats = base_jts_e_feats.permute(1, 0, 2) else: base_jts_e_feats = self.basejtse_denoising_seqTransEncoder(base_jts_e_feats) else: #### Decode e_along_normals and e_vt_normals #### base_jts_e_time_emb = self.embed_timestep_e(timesteps) base_jts_e_feats = torch.cat( [base_jts_e_time_emb, base_jts_e_feats], dim=0 ) if self.args.use_ours_transformer_enc: ## transformer encoder ## base_jts_e_feats = self.basejtse_denoising_seqTransEncoder(base_jts_e_feats.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) base_jts_e_feats = base_jts_e_feats.permute(1, 0, 2)[1:] else: base_jts_e_feats = self.basejtse_denoising_seqTransEncoder(base_jts_e_feats)[1:] ### sequence latents ### if self.args.train_enc: # trian enc for seq latents ### base_jts_e_feats = input_latent_feats["base_jts_e_feats_enc"] # base_jts_e_feats = self.sequence_pos_denoising_encoder_e(base_jts_e_feats) #### Decode e_along_normals and e_vt_normals #### ### bae jts e embeddings feats ### # base_jts_e_time_emb = self.embed_timestep_e(timesteps) # base_jts_e_feats = torch.cat( # [base_jts_e_time_emb, base_jts_e_feats], dim=0 # ) # base_jts_e_feats = self.basejtse_denoising_seqTransEncoder(base_jts_e_feats)[1:] base_jts_e_output = self.output_process_e(base_jts_e_feats, x) # dec_e_along_normals: bsz x (ws - 1) x nnb x nnj # dec_e_along_normals = base_jts_e_output['dec_e_along_normals'] dec_e_vt_normals = base_jts_e_output['dec_e_vt_normals'] dec_d = basejtsrel_output['dec_d'] rel_vel_dec = basejtsrel_output['rel_vel_dec'] dec_e_along_normals = dec_e_along_normals.contiguous().permute(0, 1, 3, 2).contiguous() dec_e_vt_normals = dec_e_vt_normals.contiguous().permute(0, 1, 3, 2).contiguous() dec_d = dec_d.contiguous().permute(0, 1, 3, 2).contiguous() rel_vel_dec = rel_vel_dec.contiguous().permute(0, 1, 3, 2).contiguous() #### Decode e_along_normals and e_vt_normals #### diff_basejtse_dict = { 'dec_e_along_normals': dec_e_along_normals, 'dec_e_vt_normals': dec_e_vt_normals, 'base_jts_e_feats': base_jts_e_feats, 'dec_d': dec_d, 'rel_vel_dec': rel_vel_dec, } else: diff_basejtse_dict = {} rt_dict = {} rt_dict.update(diff_jts_dict) rt_dict.update(diff_basejtsrel_dict) rt_dict.update(diff_basejtse_dict) ### rt_dict --> rt_dict of joints, rel ### return rt_dict # return joint_seq_output, joints_seq_latents def reparameterization(self, val_mean, val_var): val_noise = torch.randn_like(val_mean) val_sampled = val_mean + val_noise * val_var ### sample the value if self.args.rnd_noise: val_sampled = val_noise return val_sampled def decode_realbasejtsrel_from_objbasefeats(self, objbasefeats, input_data): real_dec_basejtsrel = self.real_basejtsrel_output_process( objbasefeats, input_data ) # real_dec_basejtsrel -> decoded realtive positions # real_dec_basejtsrel = real_dec_basejtsrel['dec_rel'] real_dec_basejtsrel = real_dec_basejtsrel.permute(0, 1, 3, 2, 4).contiguous() return real_dec_basejtsrel def denoising_realbasejtsrel_objbasefeats(self, pert_obj_base_pts_feats, timesteps): if self.args.deep_fuse_timeemb: ## denoising process ### ## GET joints seq time embeddings ### ### embed time stamps ### real_basejtsrel_time_emb = self.real_basejtsrel_embed_timestep(timesteps) real_basejtsrel_time_emb = real_basejtsrel_time_emb.repeat(pert_obj_base_pts_feats.size(0), 1, 1).contiguous() real_basejtsrel_seq_latents = pert_obj_base_pts_feats + real_basejtsrel_time_emb if self.args.use_ours_transformer_enc: real_basejtsrel_seq_latents = self.real_basejtsrel_denoising_seqTransEncoder(real_basejtsrel_seq_latents.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) real_basejtsrel_seq_latents = real_basejtsrel_seq_latents.permute(1, 0, 2) else: # seq real_basejtsrel_seq_latents = self.real_basejtsrel_denoising_seqTransEncoder(real_basejtsrel_seq_latents) else: ## denoising process ### ## GET joints seq time embeddings ### ### embed time stamps ### real_basejtsrel_time_emb = self.real_basejtsrel_embed_timestep(timesteps) real_basejtsrel_seq_latents = torch.cat( [real_basejtsrel_time_emb, pert_obj_base_pts_feats], dim=0 ) if self.args.use_ours_transformer_enc: ## mdm ours ## real_basejtsrel_seq_latents = self.real_basejtsrel_denoising_seqTransEncoder(real_basejtsrel_seq_latents.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) real_basejtsrel_seq_latents = real_basejtsrel_seq_latents.permute(1, 0, 2)[1:] else: real_basejtsrel_seq_latents = self.real_basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents)[1:] # bsz, nframes, nnb, nnj, 3 --> # real_dec_basejtsrel = self.real_basejtsrel_output_process( # real_basejtsrel_seq_latents, x # ) return real_basejtsrel_seq_latents def get_cond_parameters(self): # [finetune_with_cond_rel, finetune_with_cond_jtsobj] # cond_real_basejtsrel_input_process, cond_real_basejtsrel_pos_encoder, cond_real_basejtsrel_seqTransEncoderLayer, cond_trans_linear_layer # # finetune_with_cond_jtsobj : cond_joints_offset_input_process <- joints_offset_input_process; cond_sequence_pos_encoder <- sequence_pos_encoder; cond_seqTransEncoder <- seqTransEncoder # cond_obj_input_layer, cond_obj_trans_layer, cond_joints_offset_input_process, cond_sequence_pos_encoder, cond_seqTransEncoder # if self.args.diff_basejtsrel: # TODO: finetune_with_cond_jtsobj : cond_joints_offset_input_process <- joints_offset_input_process; cond_sequence_pos_encoder <- sequence_pos_encoder; cond_seqTransEncoder <- seqTransEncoder if self.args.finetune_with_cond_rel: params = list(self.cond_real_basejtsrel_input_process.parameters()) + list(self.cond_real_basejtsrel_pos_encoder.parameters()) + list(self.cond_real_basejtsrel_seqTransEncoderLayer.parameters()) + list(self.cond_trans_linear_layer.parameters()) elif self.args.finetune_with_cond_jtsobj: params = list(self.cond_joints_offset_input_process.parameters()) + list(self.cond_sequence_pos_encoder.parameters()) + list(self.cond_seqTransEncoder.parameters()) + list(self.cond_obj_trans_layer.parameters()) + list(self.cond_obj_input_layer.parameters()) else: raise ValueError(f"Either ``finetune_with_cond_rel'' or ``finetune_with_cond_jtsobj'' should be activated !") else: raise ValueError(f"Must use diff_basejtsrel currently, others have not been implemented yet.") return params def set_trans_linear_layer_to_zero(self): torch.nn.init.zeros_(self.cond_trans_linear_layer.weight) torch.nn.init.zeros_(self.cond_trans_linear_layer.bias) # finetune_with_cond_jtsobj: # cond_obj_trans_layer if self.args.finetune_with_cond_jtsobj: torch.nn.init.zeros_(self.cond_obj_trans_layer.weight) torch.nn.init.zeros_(self.cond_obj_trans_layer.bias) def forward(self, x, timesteps): """ x: [batch_size, njoints, nfeats, max_frames], denoted x_t in the paper timesteps: [batch_size] (int) """ # joint_sequence_input_process; joint_sequence_pos_encoder; joint_sequence_seqTransEncoder; joint_sequence_seqTransDecoder; joint_sequence_embed_timestep; joint_sequence_output_process # # bsz, nframes, nnj = x['pert_rhand_joints'].shape[:3] # pert_rhand_joints = x['pert_rhand_joints'] # bsz x ws x nnj x 3 # bsz x nnj x 3 x ws bsz, nframes, nnj = x['rhand_joints'].shape[:3] pert_rhand_joints = x['rhand_joints'] # bsz x ws x nnj x 3 # bsz x nnj x 3 x ws base_pts = x['base_pts'] ### bsz x nnb x 3 ### base_normals = x['base_normals'] ### bsz x nnb x 3 ### --> base normals ### # base_normals # ## rt_dict = {} ## # # logvar_seqTransEncoder_e, logvar_seqTransEncoder, joint_sequence_logvar_seqTransEncoder # # if self.diff_basejtse: ### Embed physicss quantities ### # e_disp_rel_to_base_along_normals: bsz x (ws - 1) x nnj x nnb # # e_disp_rel_to_baes_vt_normals: bsz x (ws - 1) x nnj x nnb # # e_disp_rel_to_base_along_normals = x['e_disp_rel_to_base_along_normals'] # e_disp_rel_to_baes_vt_normals = x['e_disp_rel_to_baes_vt_normals'] e_disp_rel_to_base_along_normals = x['pert_e_disp_rel_to_base_along_normals'] e_disp_rel_to_baes_vt_normals = x['pert_e_disp_rel_to_base_vt_normals'] obj_pts_disp = x['obj_pts_disp'] vel_obj_pts_to_hand_pts = x['vel_obj_pts_to_hand_pts'] disp_dist = x['disp_dist'] nnb = base_pts.size(1) disp_ws = e_disp_rel_to_base_along_normals.size(1) ### --> base normals ### base_pts_disp_exp = base_pts.unsqueeze(1).unsqueeze(1).repeat(1, disp_ws, nnj, 1, 1).contiguous() base_normals_disp_exp = base_normals.unsqueeze(1).unsqueeze(1).repeat(1, disp_ws, nnj, 1, 1).contiguous() obj_pts_disp_exp = obj_pts_disp.unsqueeze(1).unsqueeze(1).repeat(1, disp_ws, nnj, 1, 1).contiguous() # bsz x (ws - 1) x nnj x nnb x (3 + 3 + 1 + 1) base_pts_normals_e_in_feats = torch.cat( # along normals; # vt normals # [base_pts_disp_exp, base_normals_disp_exp, obj_pts_disp_exp, vel_obj_pts_to_hand_pts, disp_dist, e_disp_rel_to_base_along_normals.unsqueeze(-1), e_disp_rel_to_baes_vt_normals.unsqueeze(-1)], dim=-1 ) base_pts_normals_e_in_feats = base_pts_normals_e_in_feats.permute(0, 1, 3, 2, 4).contiguous() # bsz x (ws - 1) x nnb x (nnj x (xxx feats_dim)) base_pts_normals_e_in_feats = base_pts_normals_e_in_feats.view(bsz, disp_ws, nnb, -1).contiguous() ## input process ## base_jts_e_feats = self.input_process_e(base_pts_normals_e_in_feats) base_jts_e_feats = self.sequence_pos_encoder_e(base_jts_e_feats) ## seq transformation for e ## # base_jts_e_feats_mean = self.seqTransEncoder_e(base_jts_e_feats) ## mean, mdm_ours ## # print(f"base_jts_e_feats: {base_jts_e_feats.size()}") ### Embed physicss quantities ### #### base_jts_e_feats, base_jts_e_feats_mean #### # ## us basejtsefeats for denoising directly ## base_jts_e_feats = base_jts_e_feats_mean if not self.args.without_dec_pos_emb: ## use positional encoding ## # rel_base_pts_outputs = self.sequence_pos_denoising_encoder(rel_base_pts_outputs) base_jts_e_feats = self.sequence_pos_denoising_encoder_e(base_jts_e_feats) if self.args.deep_fuse_timeemb: ## denoising process ### ## GET joints seq time embeddings ### ### embed time stamps ### base_jts_e_time_emb = self.embed_timestep_e(timesteps) base_jts_e_time_emb = base_jts_e_time_emb.repeat(base_jts_e_feats.size(0), 1, 1).contiguous() base_jts_e_feats = base_jts_e_feats + base_jts_e_time_emb if self.args.use_ours_transformer_enc: ## transformer encoder ## base_jts_e_feats = self.basejtse_denoising_seqTransEncoder(base_jts_e_feats.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) base_jts_e_feats = base_jts_e_feats.permute(1, 0, 2) else: base_jts_e_feats = self.basejtse_denoising_seqTransEncoder(base_jts_e_feats) else: #### Decode e_along_normals and e_vt_normals #### base_jts_e_time_emb = self.embed_timestep_e(timesteps) base_jts_e_feats = torch.cat( [base_jts_e_time_emb, base_jts_e_feats], dim=0 ) if self.args.use_ours_transformer_enc: ## transformer encoder ## base_jts_e_feats = self.basejtse_denoising_seqTransEncoder(base_jts_e_feats.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) base_jts_e_feats = base_jts_e_feats.permute(1, 0, 2)[1:] else: base_jts_e_feats = self.basejtse_denoising_seqTransEncoder(base_jts_e_feats)[1:] # ### sequence latents ### # if self.args.train_enc: # trian enc for seq latents ### # base_jts_e_feats = input_latent_feats["base_jts_e_feats_enc"] # base_jts_e_feats = self.sequence_pos_denoising_encoder_e(base_jts_e_feats) #### Decode e_along_normals and e_vt_normals #### ### bae jts e embeddings feats ### # base_jts_e_time_emb = self.embed_timestep_e(timesteps) # base_jts_e_feats = torch.cat( # [base_jts_e_time_emb, base_jts_e_feats], dim=0 # ) # base_jts_e_feats = self.basejtse_denoising_seqTransEncoder(base_jts_e_feats)[1:] ##### output_process_e -> output energies ##### base_jts_e_output = self.output_process_e(base_jts_e_feats, x) # dec_e_along_normals: bsz x (ws - 1) x nnb x nnj # dec_e_along_normals = base_jts_e_output['dec_e_along_normals'] dec_e_vt_normals = base_jts_e_output['dec_e_vt_normals'] dec_d = base_jts_e_output['dec_d'] rel_vel_dec = base_jts_e_output['rel_vel_dec'] dec_e_along_normals = dec_e_along_normals.contiguous().permute(0, 1, 3, 2).contiguous() # bsz x (ws - 1) x nnj x nnb dec_e_vt_normals = dec_e_vt_normals.contiguous().permute(0, 1, 3, 2).contiguous() # bsz x (ws - 1) x nnj x nnb #### Decode e_along_normals and e_vt_normals #### diff_basejtse_dict = { 'dec_e_along_normals': dec_e_along_normals, 'dec_e_vt_normals': dec_e_vt_normals, 'base_jts_e_feats': base_jts_e_feats, 'dec_d': dec_d, 'rel_vel_dec': rel_vel_dec, } # rt_dict['base_jts_e_feats'] = base_jts_e_feats # rt_dict['base_jts_e_feats_mean'] = base_jts_e_feats_mean # rt_dict['base_jts_e_feats_logvar'] = base_jts_e_feats_logvar # log_var # else: diff_basejtse_dict = {} if self.diff_jts: # base_pts_normal ### InputProcess ### pert_rhand_joints_trans = pert_rhand_joints.permute(0, 2, 3, 1).contiguous() # bsz x nnj x 3 x ws # rhand_joints_feats = self.joint_sequence_input_process(pert_rhand_joints_trans) # [seqlen, bs, d] ### InputProcessObjBase ### # rhand_joints_feats = self.joint_sequence_input_process(pert_rhand_joints) ### === Encode input joint sequences === ### # bs, njoints, nfeats, nframes = x.shape # rhand_joints_emb = self.joint_sequence_embed_timestep(timesteps) # [1, bs, d] # if self.arch == 'trans_enc': xseq = rhand_joints_feats # [seqlen+1, bs, d] xseq = self.joint_sequence_pos_encoder(xseq) # [seqlen+1, bs, d] joint_seq_output_mean = self.joint_sequence_seqTransEncoder(xseq) # [1:] # , src_key_padding_mask=~maskseq) # [seqlen, bs, d] ### calculate logvar, mean, and feats ### joint_seq_output_logvar = self.joint_sequence_logvar_seqTransEncoder(xseq) # # logvar_seqTransEncoder_e, logvar_seqTransEncoder, joint_sequence_logvar_seqTransEncoder # # joint_seq_output_var = torch.exp(joint_seq_output_logvar) # seq x bs x d --> encodeing and decoding ## base_jts_e_feats: seqlen x bs x d --> val latents ## joint_seq_output = self.reparameterization(joint_seq_output_mean, joint_seq_output_var) rt_dict['joint_seq_output'] = joint_seq_output # rt_dict['joint_seq_output'] = joint_seq_output_mean rt_dict['joint_seq_output_mean'] = joint_seq_output_mean rt_dict['joint_seq_output_logvar'] = joint_seq_output_logvar if self.args.diff_realbasejtsrel_to_joints: # nframes x nnbase x nnjts x (base pts + base normals + 3) 2) point feature for each point; point feature for; condition on the noisy input for the denoised information # real_basejtsrel_to_joints_input_process, real_basejtsrel_to_joints_sequence_pos_encoder, real_basejtsrel_to_joints_seqTransEncoder # real_basejtsrel_to_joints_embed_timestep, real_basejtsrel_to_joints_sequence_pos_denoising_encoder, real_basejtsrel_to_joints_denoising_seqTransEncoder, real_basejtsrel_to_joints_output_process bsz, nf, nnj, nnb = x['pert_rel_base_pts_to_joints_for_jts_pred'].size()[:4] normed_base_pts = x['normed_base_pts'] base_normals = x['base_normals'] pert_rel_base_pts_to_rhand_joints = x['pert_rel_base_pts_to_joints_for_jts_pred'] # bsz x nf x nnj x nnb x 3 ## use_abs_jts_pos --> obj jts pos for the encodingj if self.args.use_abs_jts_pos: ## bsz x nf x nnj x nnb x 3 ## ---> abs jts pos ## pert_rel_base_pts_to_rhand_joints = pert_rel_base_pts_to_rhand_joints + normed_base_pts.unsqueeze(1).unsqueeze(1) # use_abs_jts_for_encoding, real_basejtsrel_to_joints_input_process if self.args.use_abs_jts_for_encoding: if not self.args.use_abs_jts_pos: pert_rel_base_pts_to_rhand_joints = pert_rel_base_pts_to_rhand_joints + normed_base_pts.unsqueeze(1).unsqueeze(1) # pert_rel_base_pts_to_rhand_joints: bsz x nf x nnj x nnb x 3 abs_jts = pert_rel_base_pts_to_rhand_joints[..., 0, :] abs_jts = abs_jts.permute(0, 2, 3, 1).contiguous() obj_base_encoded_feats = self.real_basejtsrel_to_joints_input_process(abs_jts) elif self.args.use_abs_jts_for_encoding_obj_base: if not self.args.use_abs_jts_pos: pert_rel_base_pts_to_rhand_joints = pert_rel_base_pts_to_rhand_joints + normed_base_pts.unsqueeze(1).unsqueeze(1) # pert_rel_base_pts_to_rhand_joints: bsz x nf x nnj x nnb x 3 pert_rel_base_pts_to_rhand_joints = pert_rel_base_pts_to_rhand_joints[:, :, :, 0:1, :] # obj_base_in_feats = torch.cat( # bsz x nf x nnj x nnb x (3 + 3 + 3) # [normed_base_pts.unsqueeze(1).unsqueeze(1).repeat(1, nf, nnj, 1, 1)[:, :, :, 0:1, :], base_normals.unsqueeze(1).unsqueeze(1).repeat(1, nf, nnj, 1, 1)[:, :, :, 0:1, :], pert_rel_base_pts_to_rhand_joints], dim=-1 # ) obj_base_in_feats = pert_rel_base_pts_to_rhand_joints # --> tnrasform the input feature dim to 21 * 3 here for encoding # obj_base_in_feats = obj_base_in_feats.transpose(-2, -3).contiguous().view(bsz, nf, 1, -1).contiguous() # obj_base_encoded_feats = self.real_basejtsrel_to_joints_input_process(obj_base_in_feats) # nf x bsz x feat_dim # else: if self.args.use_objbase_v2: # bsz x nf x nnj x nnb x 3 pert_rel_base_pts_to_rhand_joints_exp = pert_rel_base_pts_to_rhand_joints.transpose(-2, -3).contiguous().view(bsz, nf, nnb, -1).contiguous() obj_base_in_feats = torch.cat( [normed_base_pts.unsqueeze(1).repeat(1, nf, 1, 1), base_normals.unsqueeze(1).repeat(1, nf, 1, 1), pert_rel_base_pts_to_rhand_joints_exp], dim=-1 ) else: obj_base_in_feats = torch.cat( # bsz x nf x nnj x nnb x (3 + 3 + 3) [normed_base_pts.unsqueeze(1).unsqueeze(1).repeat(1, nf, nnj, 1, 1), base_normals.unsqueeze(1).unsqueeze(1).repeat(1, nf, nnj, 1, 1), pert_rel_base_pts_to_rhand_joints], dim=-1 ) # print(f"obj_base_in_feats: {obj_base_in_feats.size()}, bsz: {bsz}, nf: {nf}, nnj: {nnj}, nnb: {nnb}") obj_base_in_feats = obj_base_in_feats.transpose(-2, -3).contiguous().view(bsz, nf, nnb, -1).contiguous() # obj_base_encoded_feats = self.real_basejtsrel_to_joints_input_process(obj_base_in_feats) # nf x bsz x feat_dim # # ### real_basejtsrel_to_joints_input_process --> real_basejtsrel_to_joints_input_process --> for the joints and input process ### # obj_base_encoded_feats # obj_base_encoded_feats obj_base_pts_feats_pos_embedding = self.real_basejtsrel_to_joints_sequence_pos_encoder(obj_base_encoded_feats) obj_base_pts_feats = self.real_basejtsrel_to_joints_seqTransEncoder(obj_base_pts_feats_pos_embedding) if self.args.use_sigmoid: obj_base_pts_feats = (torch.sigmoid(obj_base_pts_feats) - 0.5) * 2. if self.args.deep_fuse_timeemb: ## denoising process ### ## GET joints seq time embeddings ### ### embed time stamps ### real_basejtsrel_time_emb = self.real_basejtsrel_to_joints_embed_timestep(timesteps) real_basejtsrel_time_emb = real_basejtsrel_time_emb.repeat(obj_base_pts_feats.size(0), 1, 1).contiguous() real_basejtsrel_seq_latents = obj_base_pts_feats + real_basejtsrel_time_emb if self.args.use_ours_transformer_enc: real_basejtsrel_seq_latents = self.real_basejtsrel_to_joints_denoising_seqTransEncoder(real_basejtsrel_seq_latents.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) real_basejtsrel_seq_latents = real_basejtsrel_seq_latents.permute(1, 0, 2) else: # seq des real_basejtsrel_seq_latents = self.real_basejtsrel_to_joints_denoising_seqTransEncoder(real_basejtsrel_seq_latents) else: ## denoising process ### ## GET joints seq time embeddings ### ### embed time stamps ### real_basejtsrel_time_emb = self.real_basejtsrel_to_joints_embed_timestep(timesteps) real_basejtsrel_seq_latents = torch.cat( [real_basejtsrel_time_emb, obj_base_pts_feats], dim=0 ) if self.args.use_ours_transformer_enc: ## mdm ours ## real_basejtsrel_seq_latents = self.real_basejtsrel_to_joints_denoising_seqTransEncoder(real_basejtsrel_seq_latents.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) real_basejtsrel_seq_latents = real_basejtsrel_seq_latents.permute(1, 0, 2)[1:] else: real_basejtsrel_seq_latents = self.real_basejtsrel_to_joints_denoising_seqTransEncoder(basejtsrel_seq_latents)[1:] joints_offset_output = self.real_basejtsrel_to_joints_output_process(real_basejtsrel_seq_latents) joints_offset_output = joints_offset_output.permute(0, 3, 1, 2) if self.args.diff_basejtsrel: diff_basejtsrel_to_joints_dict = { 'joints_offset_output_from_rel': joints_offset_output } else: diff_basejtsrel_to_joints_dict = { 'joints_offset_output': joints_offset_output } else: diff_basejtsrel_to_joints_dict = {} if self.diff_realbasejtsrel: # real_dec_basejtsrel # real_basejtsrel_input_process, real_basejtsrel_sequence_pos_encoder, real_basejtsrel_seqTransEncoder, real_basejtsrel_embed_timestep, real_basejtsrel_sequence_pos_denoising_encoder, real_basejtsrel_denoising_seqTransEncoder bsz, nf, nnj, nnb = x['pert_rel_base_pts_to_rhand_joints'].size()[:4] normed_base_pts = x['normed_base_pts'] base_normals = x['base_normals'] pert_rel_base_pts_to_rhand_joints = x['pert_rel_base_pts_to_rhand_joints'] # bsz x nf x nnj x nnb x 3 if self.args.use_objbase_v2: # bsz x nf x nnj x nnb x 3 pert_rel_base_pts_to_rhand_joints_exp = pert_rel_base_pts_to_rhand_joints.transpose(-2, -3).contiguous().view(bsz, nf, nnb, -1).contiguous() obj_base_in_feats = torch.cat( [normed_base_pts.unsqueeze(1).repeat(1, nf, 1, 1), base_normals.unsqueeze(1).repeat(1, nf, 1, 1), pert_rel_base_pts_to_rhand_joints_exp], dim=-1 ) elif self.args.use_objbase_v4: # use_objbase_v4: # use_objbase_out_v4 exp_normed_base_pts = normed_base_pts.unsqueeze(1).unsqueeze(2).repeat(1, nf, nnj, 1, 1).contiguous() exp_base_normals = base_normals.unsqueeze(1).unsqueeze(2).repeat(1, nf, nnj, 1, 1).contiguous() obj_base_in_feats = torch.cat( [pert_rel_base_pts_to_rhand_joints, exp_normed_base_pts, exp_base_normals], dim=-1 # bsz x nf x nnj x nnb x (3 + 3 + 3) # -> exp_base_normals ) obj_base_in_feats = obj_base_in_feats.view(bsz, nf, nnj, -1).contiguous() elif self.args.use_objbase_v5: # use_objbase_v5, use_objbase_out_v5 pert_rel_base_pts_to_rhand_joints_exp = pert_rel_base_pts_to_rhand_joints.transpose(-2, -3).contiguous().view(bsz, nf, nnb, -1).contiguous() if self.args.v5_in_not_base: obj_base_in_feats = torch.cat( [ pert_rel_base_pts_to_rhand_joints_exp], dim=-1 ) elif self.args.v5_in_not_base_pos: obj_base_in_feats = torch.cat( [base_normals.unsqueeze(1).repeat(1, nf, 1, 1), pert_rel_base_pts_to_rhand_joints_exp], dim=-1 ) else: obj_base_in_feats = torch.cat( [normed_base_pts.unsqueeze(1).repeat(1, nf, 1, 1), base_normals.unsqueeze(1).repeat(1, nf, 1, 1), pert_rel_base_pts_to_rhand_joints_exp], dim=-1 ) elif self.args.use_objbase_v6 or self.args.use_objbase_v7: pert_rel_base_pts_to_rhand_joints_exp = pert_rel_base_pts_to_rhand_joints.transpose(-2, -3).contiguous().view(bsz, nf, nnb, -1).contiguous() obj_base_in_feats = torch.cat( [normed_base_pts.unsqueeze(1).repeat(1, nf, 1, 1), base_normals.unsqueeze(1).repeat(1, nf, 1, 1), pert_rel_base_pts_to_rhand_joints_exp], dim=-1 ) else: obj_base_in_feats = torch.cat( # bsz x nf x nnj x nnb x (3 + 3 + 3) [normed_base_pts.unsqueeze(1).unsqueeze(1).repeat(1, nf, nnj, 1, 1), base_normals.unsqueeze(1).unsqueeze(1).repeat(1, nf, nnj, 1, 1), pert_rel_base_pts_to_rhand_joints], dim=-1 ) # print(f"obj_base_in_feats: {obj_base_in_feats.size()}, bsz: {bsz}, nf: {nf}, nnj: {nnj}, nnb: {nnb}") obj_base_in_feats = obj_base_in_feats.transpose(-2, -3).contiguous().view(bsz, nf, nnb, -1).contiguous() # # obj_base_in_feats = torch.cat( # bsz x nf x nnj x nnb x (3 + 3 + 3) # [normed_base_pts.unsqueeze(1).unsqueeze(1).repeat(1, nf, nnj, 1, 1), base_normals.unsqueeze(1).unsqueeze(1).repeat(1, nf, nnj, 1, 1), pert_rel_base_pts_to_rhand_joints], dim=-1 # ) # print(f"obj_base_in_feats: {obj_base_in_feats.size()}, bsz: {bsz}, nf: {nf}, nnj: {nnj}, nnb: {nnb}") # obj_base_in_feats = obj_base_in_feats.transpose(-2, -3).contiguous().view(bsz, nf, nnb, -1).contiguous() # if self.args.use_objbase_v6: normed_base_pts_exp = normed_base_pts.unsqueeze(1).repeat(1, nf, 1, 1) # and repeat for the base pts # obj_base_encoded_feats = self.real_basejtsrel_input_process(obj_base_in_feats, normed_base_pts_exp) else: obj_base_encoded_feats = self.real_basejtsrel_input_process(obj_base_in_feats) # nf x bsz x feat_dim # nf x bsz x nnbasepts x feats_dim # # obj_base_encoded_feats obj_base_pts_feats_pos_embedding = self.real_basejtsrel_sequence_pos_encoder(obj_base_encoded_feats) obj_base_pts_feats = self.real_basejtsrel_seqTransEncoder(obj_base_pts_feats_pos_embedding) if self.args.use_sigmoid: obj_base_pts_feats = (torch.sigmoid(obj_base_pts_feats) - 0.5) * 2. if self.args.train_enc: # basejtsrel_seq # bsz, nframes, nnb, nnj, 3 --> # real_dec_basejtsrel = self.real_basejtsrel_output_process( obj_base_pts_feats, x ) real_dec_basejtsrel = real_dec_basejtsrel['dec_rel'] if self.args.use_objbase_out_v3: real_dec_basejtsrel = real_dec_basejtsrel else: real_dec_basejtsrel = real_dec_basejtsrel.permute(0, 1, 3, 2, 4).contiguous() # bsz x nf x nnj x nnb x 3 else: if self.args.deep_fuse_timeemb: ## denoising process ### ## GET joints seq time embeddings ### ### embed time stamps ### # print(f"timesteps: {timesteps.size()}, obj_base_pts_feats: {obj_base_pts_feats.size()}") if self.args.use_objbase_v5: cur_timesteps = timesteps.unsqueeze(1).repeat(1, nnb).view(-1) else: cur_timesteps = timesteps real_basejtsrel_time_emb = self.real_basejtsrel_embed_timestep(cur_timesteps) real_basejtsrel_time_emb = real_basejtsrel_time_emb.repeat(obj_base_pts_feats.size(0), 1, 1).contiguous() real_basejtsrel_seq_latents = obj_base_pts_feats + real_basejtsrel_time_emb if self.args.use_ours_transformer_enc: real_basejtsrel_seq_latents = self.real_basejtsrel_denoising_seqTransEncoder(real_basejtsrel_seq_latents.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) real_basejtsrel_seq_latents = real_basejtsrel_seq_latents.permute(1, 0, 2) else: # seq des real_basejtsrel_seq_latents = self.real_basejtsrel_denoising_seqTransEncoder(real_basejtsrel_seq_latents) else: ## denoising process ### ## GET joints seq time embeddings ### ### embed time stamps ### real_basejtsrel_time_emb = self.real_basejtsrel_embed_timestep(timesteps) real_basejtsrel_seq_latents = torch.cat( [real_basejtsrel_time_emb, obj_base_pts_feats], dim=0 ) if self.args.use_ours_transformer_enc: ## mdm ours ## real_basejtsrel_seq_latents = self.real_basejtsrel_denoising_seqTransEncoder(real_basejtsrel_seq_latents.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) real_basejtsrel_seq_latents = real_basejtsrel_seq_latents.permute(1, 0, 2)[1:] else: real_basejtsrel_seq_latents = self.real_basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents)[1:] # basejtsrel_seq # bsz, nframes, nnb, nnj, 3 --> # if self.args.use_jts_pert_realbasejtsrel: joints_offset_output = self.real_basejtsrel_output_process(real_basejtsrel_seq_latents) joints_offset_output = joints_offset_output.permute(0, 3, 1, 2) # bsz x nf x nnj x 3 real_dec_basejtsrel = joints_offset_output.unsqueeze(-2).repeat(1, 1, 1, nnb, 1) # real_dec_basejtsrel = joints_offset_output else: real_dec_basejtsrel = self.real_basejtsrel_output_process( real_basejtsrel_seq_latents, x ) # real_dec_basejtsrel -> decoded realtive positions # real_dec_basejtsrel = real_dec_basejtsrel['dec_rel'] if self.args.use_objbase_out_v3 or self.args.use_objbase_out_v4 or self.args.use_objbase_out_v5: real_dec_basejtsrel = real_dec_basejtsrel else: real_dec_basejtsrel = real_dec_basejtsrel.permute(0, 1, 3, 2, 4).contiguous() # bsz x nf x nnj x nnb x 3 diff_realbasejtsrel_out_dict = { 'real_dec_basejtsrel': real_dec_basejtsrel, 'obj_base_pts_feats': obj_base_pts_feats, } else: diff_realbasejtsrel_out_dict = {} if self.diff_basejtsrel: bsz, nf, nnj, nnb = x['pert_rel_base_pts_to_rhand_joints'].size()[:4] # cond_real_basejtsrel_input_process, cond_real_basejtsrel_pos_encoder, cond_real_basejtsrel_seqTransEncoderLayer, cond_trans_linear_layer # normed_base_pts = x['normed_base_pts'] base_normals = x['base_normals'] pert_rel_base_pts_to_rhand_joints = x['pert_rel_base_pts_to_rhand_joints'] # bsz x nf x nnj x nnb x 3 # conditional strategies and representations for denoising # # conditional strategies # [finetune_with_cond_rel, finetune_with_cond_jtsobj] # # cond_real_basejtsrel_input_process, cond_real_basejtsrel_pos_encoder, cond_real_basejtsrel_seqTransEncoderLayer, cond_trans_linear_layer # # finetune_with_cond_jtsobj : cond_joints_offset_input_process <- joints_offset_input_process; cond_sequence_pos_encoder <- sequence_pos_encoder; cond_seqTransEncoder <- seqTransEncoder if self.args.finetune_with_cond_rel: # finetune with cond rel # finetune with cond # finetune with if self.args.use_objbase_v5: # use_objbase_v5, use_objbase_out_v5 pert_rel_base_pts_to_rhand_joints_exp = pert_rel_base_pts_to_rhand_joints.transpose(-2, -3).contiguous().view(bsz, nf, nnb, -1).contiguous() if self.args.v5_in_not_base: obj_base_in_feats = torch.cat( [ pert_rel_base_pts_to_rhand_joints_exp], dim=-1 ) elif self.args.v5_in_not_base_pos: obj_base_in_feats = torch.cat( [base_normals.unsqueeze(1).repeat(1, nf, 1, 1), pert_rel_base_pts_to_rhand_joints_exp], dim=-1 ) else: obj_base_in_feats = torch.cat( [normed_base_pts.unsqueeze(1).repeat(1, nf, 1, 1), base_normals.unsqueeze(1).repeat(1, nf, 1, 1), pert_rel_base_pts_to_rhand_joints_exp], dim=-1 ) else: raise ValueError(f"Must use objbase_v5 currently, others have not been implemented yet.") # print(f'obj_base_in_feats: {obj_base_in_feats.size()}') obj_base_encoded_feats = self.cond_real_basejtsrel_input_process(obj_base_in_feats) obj_base_pts_feats_pos_embedding = self.cond_real_basejtsrel_pos_encoder(obj_base_encoded_feats) # print(f"obj_base_pts_feats_pos_embedding: {obj_base_pts_feats_pos_embedding.size()}") obj_base_pts_feats = self.cond_real_basejtsrel_seqTransEncoderLayer(obj_base_pts_feats_pos_embedding) # seq x bsz x latent_dim elif self.args.finetune_with_cond_jtsobj: # cond_obj_input_layer, cond_obj_trans_layer, cond_joints_offset_input_process, cond_sequence_pos_encoder, cond_seqTransEncoder # cond_joints_offset_sequence = x['pert_joints_offset_sequence'] # bsz x nf x nnj x 3 cond_joints_offset_sequence = cond_joints_offset_sequence.permute(0, 2, 3, 1).contiguous() cond_joints_offset_feats = self.cond_joints_offset_input_process(cond_joints_offset_sequence) # nf x bsz x dim pert_rel_base_pts_to_rhand_joints_exp = pert_rel_base_pts_to_rhand_joints.transpose(-2, -3).contiguous().view(bsz, nf, nnb, -1).contiguous() # if self.cond_obj_input_feats == 3: normed_base_pts_exp = normed_base_pts.unsqueeze(1).repeat(1, nf, 1, 1).contiguous() elif self.cond_obj_input_feats == 6: normed_base_pts_exp = normed_base_pts.unsqueeze(1).repeat(1, nf, 1, 1).contiguous() normed_base_normals_exp = base_normals.unsqueeze(1).repeat(1, nf, 1, 1).contiguous() normed_base_pts_exp = torch.cat( [normed_base_pts_exp, normed_base_normals_exp], dim=-1 ## bsz x nf x nn_base_pts x (3 + 3) ) elif self.cond_obj_input_feats == (6 + 21 * 3): # 63 + 6 -> 69 normed_base_pts_exp = normed_base_pts.unsqueeze(1).repeat(1, nf, 1, 1).contiguous() normed_base_normals_exp = base_normals.unsqueeze(1).repeat(1, nf, 1, 1).contiguous() pert_rel_base_pts_to_rhand_joints_exp = pert_rel_base_pts_to_rhand_joints.transpose(-2, -3).contiguous().view(bsz, nf, nnb, -1).contiguous() normed_base_pts_exp = torch.cat( [normed_base_pts_exp, normed_base_normals_exp, pert_rel_base_pts_to_rhand_joints_exp], dim=-1 ) else: # cond_obj_input_feats --> cond_obj_input_feats # finetune_cond_obj_feats_dim raise ValueError(f"Unrecognized cond_obj_input_feats: {self.cond_obj_input_feats}") cond_obj_feats = self.cond_obj_input_layer(normed_base_pts_exp) # nf x bsz x feats_dim # cond_obj_feats = self.cond_obj_trans_layer(cond_obj_feats) cond_jtsobj_feats = cond_joints_offset_feats + cond_obj_feats cond_jtsobj_feats = self.cond_sequence_pos_encoder(cond_jtsobj_feats) obj_base_pts_feats = self.cond_seqTransEncoder(cond_jtsobj_feats) else: raise ValueError(f"Either ``finetune_with_cond_rel'' or ``finetune_with_cond_jtsobj'' should be activated! ") obj_base_pts_feats = self.cond_trans_linear_layer(obj_base_pts_feats) # print(f"obj_base_pts_feats: {obj_base_pts_feats.size()}") ### ==== Encoding joints ==== ### joints_offset_sequence = x['pert_joints_offset_sequence'] # bsz x nf x nnj x 3 joints_offset_sequence = joints_offset_sequence.permute(0, 2, 3, 1).contiguous() joints_offset_feats = self.joints_offset_input_process(joints_offset_sequence) # nf x bsz x dim rel_base_pts_feats_pos_embedding = self.sequence_pos_encoder(joints_offset_feats) rel_base_pts_feats = rel_base_pts_feats_pos_embedding # seqTransEncoder, logvar_seqTransEncoder # rel_base_pts_outputs_mean = self.seqTransEncoder(rel_base_pts_feats) ### ==== Encoding joints ==== ### ### ==== fuse conditional embeddings with joints embeddigns ==== ### rel_base_pts_outputs_mean = rel_base_pts_outputs_mean + obj_base_pts_feats ### ==== fuse conditional embeddings with joints embeddigns ==== ### # print(f"rel_base_pts_outputs_mean: {rel_base_pts_outputs_mean.size()}") ### ==== denoise latent features ==== ### if not self.args.without_dec_pos_emb: # without dec pos embedding # avg_jts_inputs = rel_base_pts_outputs_mean[0:1, ...] other_rel_base_pts_outputs = rel_base_pts_outputs_mean # other_rel_base_pts_outputs = self.sequence_pos_denoising_encoder(other_rel_base_pts_outputs) rel_base_pts_outputs_mean = other_rel_base_pts_outputs if self.args.deep_fuse_timeemb: basejtsrel_time_emb = self.basejtsrel_denoising_embed_timestep(timesteps) basejtsrel_time_emb = basejtsrel_time_emb.repeat(rel_base_pts_outputs_mean.size(0), 1, 1).contiguous() basejtsrel_seq_latents = rel_base_pts_outputs_mean + basejtsrel_time_emb if self.args.use_ours_transformer_enc: basejtsrel_seq_latents = self.basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) basejtsrel_seq_latents = basejtsrel_seq_latents.permute(1, 0, 2) else: basejtsrel_seq_latents = self.basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents) else: ## denoising process ### ## GET joints seq time embeddings ### ### embed time stamps ### basejtsrel_time_emb = self.basejtsrel_denoising_embed_timestep(timesteps) basejtsrel_seq_latents = torch.cat( [basejtsrel_time_emb, rel_base_pts_outputs_mean], dim=0 ) if self.args.use_ours_transformer_enc: ## mdm ours ## basejtsrel_seq_latents = self.basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents.permute(1, 0, 2), set_attn_to_none=self.args.set_attn_to_none) basejtsrel_seq_latents = basejtsrel_seq_latents.permute(1, 0, 2)[1:] else: basejtsrel_seq_latents = self.basejtsrel_denoising_seqTransEncoder(basejtsrel_seq_latents)[1:] ### ==== denoise latent features ==== ### other_basejtsrel_seq_latents = basejtsrel_seq_latents # [1:, ...] joints_offset_output = self.joint_offset_output_process(other_basejtsrel_seq_latents) joints_offset_output = joints_offset_output.permute(0, 3, 1, 2) diff_basejtsrel_dict = { 'joints_offset_output': joints_offset_output } else: diff_basejtsrel_dict = {} rt_dict = {} rt_dict.update(diff_basejtsrel_dict) rt_dict.update(diff_basejtse_dict) ### rt_dict and diff_basejtse rt_dict.update(diff_basejtsrel_to_joints_dict) rt_dict.update(diff_realbasejtsrel_out_dict) ### diff return rt_dict def _apply(self, fn): super()._apply(fn) # self.rot2xyz.smpl_model._apply(fn) def train(self, *args, **kwargs): super().train(*args, **kwargs) # self.rot2xyz.smpl_model.train(*args, **kwargs) ### LayerNorm layer ### class LayerNorm(nn.Module): def __init__(self, dim, eps=1e-5): super().__init__() self.dim = dim self.eps = eps def forward(self, x): x = (x - x.mean(dim=self.dim, keepdim=True)) / torch.sqrt(x.var(dim=self.dim, keepdim=True)+self.eps) return x class PositionalEncoding(nn.Module): def __init__(self, d_model, dropout=0.1, max_len=5000): super(PositionalEncoding, self).__init__() self.dropout = nn.Dropout(p=dropout) pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0).transpose(0, 1) self.register_buffer('pe', pe) def forward(self, x): # not used in the final model x = x + self.pe[:x.shape[0], :] return self.dropout(x) class TimestepEmbedder(nn.Module): def __init__(self, latent_dim, sequence_pos_encoder): super().__init__() self.latent_dim = latent_dim self.sequence_pos_encoder = sequence_pos_encoder time_embed_dim = self.latent_dim self.time_embed = nn.Sequential( nn.Linear(self.latent_dim, time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim), ) def forward(self, timesteps): return self.time_embed(self.sequence_pos_encoder.pe[timesteps]).permute(1, 0, 2) ## hand sequence; object shape # InputProcessObjBase(self, data_rep, input_feats, latent_dim) class InputProcessObjBaseV5(nn.Module): # inputobjbase def __init__(self, data_rep, input_feats, latent_dim, layernorm=True, without_glb=False, only_with_glb=False): super().__init__() self.data_rep = data_rep self.input_feats = input_feats # 21 * 3 + 3 + 3 --> for each joint + 3 pos + 3 normals # self.latent_dim = latent_dim self.pts_feats_encoding_net = nn.Sequential( # nnb --> 21 nn.Linear(input_feats, self.latent_dim), nn.ReLU(), nn.Linear(self.latent_dim, self.latent_dim), ) self.glb_feats_encoding_net = nn.Sequential( nn.Linear(self.latent_dim, self.latent_dim), nn.ReLU(), nn.Linear(self.latent_dim, self.latent_dim), ) self.without_glb = without_glb self.only_with_glb = only_with_glb if self.without_glb: self.pts_glb_feats_encoding_net = nn.Sequential( nn.Linear(self.latent_dim, self.latent_dim), nn.ReLU(), nn.Linear(self.latent_dim, self.latent_dim), ) else: self.pts_glb_feats_encoding_net = nn.Sequential( nn.Linear(self.latent_dim * 2, self.latent_dim), nn.ReLU(), nn.Linear(self.latent_dim, self.latent_dim), ) # self.embedding_pn_blk = nn.Sequential( # nnb --> 21 # nn.Linear(input_feats, self.latent_dim), nn.ReLU(), # nn.Linear(self.latent_dim, self.latent_dim), # ) def forward(self, x): # decode relative positions and others ## # bs, nframes, njoints, nfeats = x.shape # # bsz x nf x nnj x (3 + nnb x (3 + 3)) # bsz x nf x nnb x (latent_dim) # x: bsz x nf x nnb x (3 + 3 + 21 * 3) # x.size() bsz, nf, nnb = x.size()[:3] if self.only_with_glb: x_pts_emb = self.pts_feats_encoding_net( # input # noisy --- too noisy # x ) x_glb_emb, _ = torch.max(x_pts_emb, dim=2, keepdim=True) x_glb_emb = self.glb_feats_encoding_net(x_glb_emb) # bsz x nf x latent_dim x_glb_emb = x_glb_emb.squeeze(-2) x_pts_emb = x_glb_emb.permute(1, 0, 2).contiguous() else: x_pts_emb = self.pts_feats_encoding_net( # input # noisy --- too noisy # x ) if not self.without_glb: x_glb_emb, _ = torch.max(x_pts_emb, dim=2, keepdim=True) x_glb_emb = self.glb_feats_encoding_net(x_glb_emb) # bsz x nf x 1 x latnet_dim x_pts_emb = torch.cat( # 1 [x_pts_emb, x_glb_emb.repeat(1, 1, nnb, 1)], dim=-1 ) x_pts_emb = self.pts_glb_feats_encoding_net(x_pts_emb) # bsz x nf x nn_base_pts x latent_dim # x_pts_emb = x_pts_emb.permute(1, 0, 2, 3) # nf x bsz x nn_base_pts x latent_dim # x_pts_emb = x_pts_emb.contiguous().view(x_pts_emb.size(0), bsz * nnb, -1).contiguous() # nf x (bsz x nn_base_pts) x latent_dim # return x_pts_emb class OutputProcessObjBaseRawV5(nn.Module): def __init__(self, data_rep, latent_dim, not_cond_base=False, out_objbase_v5_bundle_out=False, v5_out_not_cond_base=False, nn_keypoints=21): super().__init__() self.data_rep = data_rep # self.input_feats = input_feats self.latent_dim = latent_dim # self.njoints = njoints self.not_cond_base = not_cond_base ## not cond base ## # self.nfeats = nfeats # dec cond on latent code and base pts, base normals # self.v5_out_not_cond_base = v5_out_not_cond_base if self.not_cond_base: self.rel_dec_cond_dim = self.latent_dim self.dist_dec_cond_dim = self.latent_dim else: self.rel_dec_cond_dim = self.latent_dim + 3 + 3 + 3 self.dist_dec_cond_dim = self.latent_dim + 3 + 3 # self.use_anchors = use_anchors self.nn_keypoints = nn_keypoints # if self.use_anchors: # self.nn_keypoints = # self.rel_dec_blk = nn.Sequential( # nn.Linear(self.rel_dec_cond_dim, 3,), # ) self.out_objbase_v5_bundle_out = out_objbase_v5_bundle_out if self.out_objbase_v5_bundle_out: if self.v5_out_not_cond_base: self.rel_dec_blk = nn.Sequential( nn.Linear(self.latent_dim, self.latent_dim // 2), nn.ReLU(), nn.Linear(self.latent_dim // 2, self.nn_keypoints * 3), ) else: self.rel_dec_blk = nn.Sequential( nn.Linear(self.latent_dim + 3 + 3, self.latent_dim // 2), nn.ReLU(), nn.Linear(self.latent_dim // 2, self.nn_keypoints * 3), ) else: self.rel_dec_blk = nn.Sequential( nn.Linear(self.rel_dec_cond_dim, 3,), ) # self.rel_dec_blk = nn.Linear( # rel_dec_blk -> output relative positions # # self.rel_dec_cond_dim, 3 * 21 # ) self.dist_dec_cond_dim = self.latent_dim + 3 + 3 self.dist_dec_blk = nn.Linear( # dist_dec_blk -> output relative distances # self.dist_dec_cond_dim, 1 * self.nn_keypoints ) # self.poseFinal = nn.Linear(self.latent_dim, self.input_feats) # if self.data_rep == 'rot_vel': # self.velFinal = nn.Linear(self.latent_dim, self.input_feats) def forward(self, output, x): # output # nframes, bs, d = output.shape # bsz, nframes, nnj, nnb = x['rel_base_pts_to_rhand_joints'].shape[:4] # pert_rel_base_pts_to_rhand_joints bsz, nframes, nnj, nnb = x['pert_rel_base_pts_to_rhand_joints'].shape[:4] # bsz x nf x nnj x nnb x 3 # nf x nnb x 3 --> noisy input for denoised values # # forward the samole # base_pts, base_normals, # # base_pts = x['base_pts'] # bsz x nnb x 3 base_pts = x['normed_base_pts'] # bsz x nnb x 3 base_normals = x['base_normals'] # bsz x nnb x 3 # rel_base_pts_to_rhand_joints = x['rel_base_pts_to_rhand_joints'] # bsz x ws x nnj x nnb x 3 # dist_base_pts_to_rhand_joints = x['dist_base_pts_to_rhand_joints'] # bsz x ws x nnj x nnb ## # output: bsz x nf x nnj x latent_dim output = output.view(nframes, bsz, nnb, -1) # nframes x bsz x nnb x latent_dim output = output.permute(1, 0, 2, 3) # bsz x nnf x nnb x latent_dim ### for the output_dim # if self.out_objbase_v5_bundle_out: if self.v5_out_not_cond_base: output_exp = output else: # otuptu_exp for rel_dec_blk base_pts_exp = base_pts.unsqueeze(1).repeat(1, nframes, 1, 1) base_normals_exp = base_normals.unsqueeze(1).repeat(1, nframes, 1, 1) output_exp = torch.cat( # with input noisy data # ############### denoised latents for each base pts ### [output, base_pts_exp, base_normals_exp], dim=-1 ) dec_rel = self.rel_dec_blk(output_exp) dec_rel = dec_rel.view(bsz, nframes, nnb, nnj, 3).permute(0, 1, 3, 2, 4).contiguous() else: # output = output.permute(1, 0, 2) # output = output.view(bsz, nframes, nnj, -1).contiguous() # bsz x nf x nnj x (decoded_latent_dim) # output = output.unsqueeze(2).repeat(1, 1, nnj, 1, 1).contiguous() # bsz x nnframes x d # # # output = output.permute(1, 0, 2).contiguous().unsqueeze(2).unsqueeze(2).repeat(1, 1, nnj, nnb, 1).contiguous() base_pts_exp = base_pts.unsqueeze(1).unsqueeze(1).repeat(1, nframes, nnj, 1, 1) base_normals_exp = base_normals.unsqueeze(1).unsqueeze(1).repeat(1, nframes, nnj, 1, 1) # bsz x nnframes x nnb x (d + 3 + 3) # --> base normals ## # if self.not_cond_base: # output_exp = output # else: output_exp = torch.cat( # with input noisy data [output, base_pts_exp, base_normals_exp, x['pert_rel_base_pts_to_rhand_joints']], dim=-1 ) dec_rel = self.rel_dec_blk(output_exp) # bsz x nnframes x nnb x (21 * 3) --> decoded relative positions # dec_rel = dec_rel.contiguous().view(bsz, nframes, nnj, nnb, 3).contiguous() # bsz x nnframes x nnb x nnj x 3 # # decoded rel, decoded distances # out = { 'dec_rel': dec_rel, # 'dec_dist': dec_dist.squeeze(-1), } return out ## output