import torch.nn as nn
from models.encdec import Encoder, Decoder
from models.quantize_cnn import QuantizeEMAReset, Quantizer, QuantizeEMA, QuantizeReset
from models.t2m_trans import Decoder_Transformer, Encoder_Transformer
from exit.utils import generate_src_mask
import torch
from utils.humanml_utils import HML_UPPER_BODY_MASK, HML_LOWER_BODY_MASK, UPPER_JOINT_Y_MASK


class VQVAE_SEP(nn.Module):
    """VQ-VAE with separate upper-body and lower-body encoders and codebooks.

    The input features are split along the last dimension with
    ``HML_UPPER_BODY_MASK`` / ``HML_LOWER_BODY_MASK``. Each part is encoded by
    its own ``Encoder`` and quantized by its own ``QuantizeEMAReset`` codebook
    of width ``code_dim / 2``. Decoding uses either a single ``Decoder`` over
    the concatenated codes (``sep_decoder=False``) or one ``Decoder`` per body
    part (``sep_decoder=True``).

    Args:
        args: Configuration object; ``args.dataname`` selects the feature
            layout ('kit' or anything else) and ``args`` is forwarded to the
            quantizers.
        nb_code: Number of entries in each codebook.
        code_dim: Total code width; each body part uses half of it.
        output_emb_width: Accepted for interface compatibility; not used here.
        down_t, stride_t, width, depth, dilation_growth_rate, activation, norm:
            Forwarded unchanged to every ``Encoder`` / ``Decoder``.
        moment: Optional mapping with ``'mean'`` and ``'std'`` entries used by
            ``normalize`` / ``denormalize``.
        sep_decoder: Build one decoder per body part instead of a joint one.
    """

    def __init__(self,
                 args,
                 nb_code=512,
                 code_dim=512,
                 output_emb_width=512,
                 down_t=3,
                 stride_t=2,
                 width=512,
                 depth=3,
                 dilation_growth_rate=3,
                 activation='relu',
                 norm=None,
                 moment=None,
                 sep_decoder=False):
        super().__init__()
        if args.dataname == 'kit':
            self.nb_joints = 21
            output_dim = 251
            upper_dim = 120
            lower_dim = 131
        else:
            self.nb_joints = 22
            output_dim = 263
            upper_dim = 156
            lower_dim = 107
        self.code_dim = code_dim

        if moment is not None:
            self.moment = moment
            # Statistics used by normalize_upper / denormalize_upper for the
            # channels selected by UPPER_JOINT_Y_MASK (13 values each).
            self.register_buffer('mean_upper', torch.tensor([0.1216, 0.2488, 0.2967, 0.5027, 0.4053, 0.4100, 0.5703, 0.4030, 0.4078, 0.1994, 0.1992, 0.0661, 0.0639], dtype=torch.float32))
            self.register_buffer('std_upper', torch.tensor([0.0164, 0.0412, 0.0523, 0.0864, 0.0695, 0.0703, 0.1108, 0.0853, 0.0847, 0.1289, 0.1291, 0.2463, 0.2484], dtype=torch.float32))

        # Each body part gets half of the total code width.
        half_dim = int(code_dim / 2)
        conv_args = (down_t, stride_t, width, depth, dilation_growth_rate)
        conv_kwargs = dict(activation=activation, norm=norm)

        # Sub-modules are created in the same order as before so that
        # parameter initialisation order and state_dict layout are unchanged.
        self.sep_decoder = sep_decoder
        if self.sep_decoder:
            self.decoder_upper = Decoder(upper_dim, half_dim, *conv_args, **conv_kwargs)
            self.decoder_lower = Decoder(lower_dim, half_dim, *conv_args, **conv_kwargs)
        else:
            self.decoder = Decoder(output_dim, code_dim, *conv_args, **conv_kwargs)

        self.num_code = nb_code

        self.encoder_upper = Encoder(upper_dim, half_dim, *conv_args, **conv_kwargs)
        self.encoder_lower = Encoder(lower_dim, half_dim, *conv_args, **conv_kwargs)
        self.quantizer_upper = QuantizeEMAReset(nb_code, half_dim, args)
        self.quantizer_lower = QuantizeEMAReset(nb_code, half_dim, args)

    def rand_emb_idx(self, x_quantized, quantizer, idx_noise):
        """Randomly replace quantized vectors with random codebook entries.

        Args:
            x_quantized: Quantized embedding; dims 1 and 2 are swapped
                internally so the mask is drawn per (dim-0, dim-2) position.
            quantizer: Quantizer whose ``dequantize`` maps indices to vectors.
            idx_noise: Bernoulli probability of replacing each position.

        Returns:
            Tensor laid out like ``x_quantized`` with some positions replaced.
        """
        x_quantized = x_quantized.permute(0, 2, 1)
        # One Bernoulli draw per position, broadcast over the last dimension.
        mask = torch.bernoulli(
            idx_noise * torch.ones((*x_quantized.shape[:2], 1), device=x_quantized.device))
        # NOTE(review): replacement indices are drawn from [0, num_code / 2),
        # i.e. only the first half of the codebook -- confirm this is intended.
        r_indices = torch.randint(
            int(self.num_code / 2), x_quantized.shape[:2], device=x_quantized.device)
        r_emb = quantizer.dequantize(r_indices)
        x_quantized = mask * r_emb + (1 - mask) * x_quantized
        x_quantized = x_quantized.permute(0, 2, 1)
        return x_quantized

    def normalize(self, data):
        """Standardize ``data`` with ``moment['mean']`` and ``moment['std']``."""
        return (data - self.moment['mean']) / self.moment['std']

    def denormalize(self, data):
        """Invert :meth:`normalize`."""
        return data * self.moment['std'] + self.moment['mean']

    def normalize_upper(self, data):
        """Standardize ``data`` with the ``mean_upper`` / ``std_upper`` buffers."""
        return (data - self.mean_upper) / self.std_upper

    def denormalize_upper(self, data):
        """Invert :meth:`normalize_upper`."""
        return data * self.std_upper + self.mean_upper

    def shift_upper_down(self, data):
        """Subtract feature 3 from the ``UPPER_JOINT_Y_MASK`` channels.

        The input is denormalized, ``data[..., 3:4]`` (presumably the root
        height -- confirm against the feature layout) is subtracted from the
        masked channels, and the result is re-normalized. The masked channels
        are normalized with the upper-body statistics instead of the global
        ones. The caller's tensor is not modified.
        """
        data = data.clone()
        data = self.denormalize(data)
        shift_y = data[..., 3:4].clone()
        data[..., UPPER_JOINT_Y_MASK] -= shift_y
        # Keep the shifted, still denormalized values for the masked channels.
        _data = data.clone()
        data = self.normalize(data)
        data[..., UPPER_JOINT_Y_MASK] = self.normalize_upper(_data[..., UPPER_JOINT_Y_MASK])
        return data

    def shift_upper_up(self, data):
        """Invert :meth:`shift_upper_down`.

        The masked channels are denormalized with the upper-body statistics,
        ``data[..., 3:4]`` is added back to them, and everything is normalized
        with the global statistics.
        """
        _data = data.clone()
        data = self.denormalize(data)
        data[..., UPPER_JOINT_Y_MASK] = self.denormalize_upper(_data[..., UPPER_JOINT_Y_MASK])
        shift_y = data[..., 3:4].clone()
        data[..., UPPER_JOINT_Y_MASK] += shift_y
        data = self.normalize(data)
        return data

    def forward(self, x, *args, type='full', **kwargs):
        """Run the model in one of three modes.

        Args:
            x: Motion features for ``'full'`` / ``'encode'``; code indices
                with upper codes in ``x[..., 0]`` and lower codes in
                ``x[..., 1]`` for ``'decode'``.
            type: One of ``'full'``, ``'encode'`` or ``'decode'``.
            **kwargs: ``idx_noise`` (optional, ``'full'`` only) is the
                probability passed to :meth:`rand_emb_idx`; applied when > 0.

        Returns:
            ``'full'``: ``(x_out, loss, perplexity)``.
            ``'encode'``: code indices with a trailing dimension of size 2
            (upper, lower).
            ``'decode'``: reconstructed motion.

        Raises:
            ValueError: If ``type`` is not one of the supported modes.
        """
        if type == 'full':
            return self._forward_full(x, kwargs.get('idx_noise', 0))
        if type == 'encode':
            return self._forward_encode(x)
        if type == 'decode':
            return self._forward_decode(x)
        raise ValueError(
            f"Unknown forward type {type!r}; expected 'full', 'encode' or 'decode'")

    def _forward_full(self, x, idx_noise):
        """Encode, quantize and decode ``x``; return output, loss, perplexity."""
        x = x.float()
        x = self.shift_upper_down(x)
        upper_emb = x[..., HML_UPPER_BODY_MASK]
        lower_emb = x[..., HML_LOWER_BODY_MASK]

        upper_emb = self.preprocess(upper_emb)
        upper_emb = self.encoder_upper(upper_emb)
        # The upper-body perplexity is discarded: only the lower-body value
        # has ever been returned, and that behaviour is kept for callers.
        upper_emb, loss_upper, _ = self.quantizer_upper(upper_emb)

        lower_emb = self.preprocess(lower_emb)
        lower_emb = self.encoder_lower(lower_emb)
        lower_emb, loss_lower, perplexity = self.quantizer_lower(lower_emb)
        loss = loss_upper + loss_lower

        if idx_noise > 0:
            upper_emb = self.rand_emb_idx(upper_emb, self.quantizer_upper, idx_noise)
            lower_emb = self.rand_emb_idx(lower_emb, self.quantizer_lower, idx_noise)

        if self.sep_decoder:
            x_decoder_upper = self.decoder_upper(upper_emb)
            x_decoder_upper = self.postprocess(x_decoder_upper)

            x_decoder_lower = self.decoder_lower(lower_emb)
            x_decoder_lower = self.postprocess(x_decoder_lower)

            x_out = merge_upper_lower(x_decoder_upper, x_decoder_lower)
            x_out = self.shift_upper_up(x_out)
        else:
            # Joint decoder: concatenate both codes along the channel dim.
            x_quantized = torch.cat([upper_emb, lower_emb], dim=1)
            x_decoder = self.decoder(x_quantized)
            x_out = self.postprocess(x_decoder)

        return x_out, loss, perplexity

    def _encode_part(self, part, encoder, quantizer, batch_size):
        """Encode one body part and return its code indices per batch item."""
        emb = self.preprocess(part)
        emb = encoder(emb)
        emb = self.postprocess(emb)
        # Flatten to (positions, channels) for the quantizer lookup.
        emb = emb.reshape(-1, emb.shape[-1])
        code_idx = quantizer.quantize(emb)
        return code_idx.view(batch_size, -1)

    def _forward_encode(self, x):
        """Return stacked (upper, lower) code indices for ``x``."""
        # Unpacking doubles as a check that ``x`` is three-dimensional.
        N, _, _ = x.shape
        x = self.shift_upper_down(x)

        upper_code_idx = self._encode_part(
            x[..., HML_UPPER_BODY_MASK], self.encoder_upper, self.quantizer_upper, N)
        lower_code_idx = self._encode_part(
            x[..., HML_LOWER_BODY_MASK], self.encoder_lower, self.quantizer_lower, N)

        code_idx = torch.cat(
            [upper_code_idx.unsqueeze(-1), lower_code_idx.unsqueeze(-1)], dim=-1)
        return code_idx

    def _forward_decode(self, x):
        """Decode code indices ``x`` (upper in [..., 0], lower in [..., 1])."""
        x_d_upper = self.quantizer_upper.dequantize(x[..., 0])
        x_d_lower = self.quantizer_lower.dequantize(x[..., 1])

        if self.sep_decoder:
            x_d_upper = x_d_upper.permute(0, 2, 1).contiguous()
            x_d_upper = self.decoder_upper(x_d_upper)
            x_d_upper = self.postprocess(x_d_upper)

            x_d_lower = x_d_lower.permute(0, 2, 1).contiguous()
            x_d_lower = self.decoder_lower(x_d_lower)
            x_d_lower = self.postprocess(x_d_lower)

            x_out = merge_upper_lower(x_d_upper, x_d_lower)
            x_out = self.shift_upper_up(x_out)
            return x_out

        x_d = torch.cat([x_d_upper, x_d_lower], dim=-1)
        x_d = x_d.permute(0, 2, 1).contiguous()
        x_decoder = self.decoder(x_d)
        x_out = self.postprocess(x_decoder)
        return x_out

    def preprocess(self, x):
        """Swap the last two dims and cast to float32."""
        # (bs, T, Jx3) -> (bs, Jx3, T)
        x = x.permute(0, 2, 1).float()
        return x

    def postprocess(self, x):
        """Swap the last two dims back."""
        # (bs, Jx3, T) -> (bs, T, Jx3)
        x = x.permute(0, 2, 1)
        return x


def merge_upper_lower(upper_emb, lower_emb):
    """Scatter upper- and lower-body features back into one motion tensor.

    Args:
        upper_emb: Features for the ``HML_UPPER_BODY_MASK`` channels.
        lower_emb: Features for the ``HML_LOWER_BODY_MASK`` channels.

    Returns:
        Tensor whose first two dims match ``upper_emb`` and whose last dim is
        the sum of both parts' widths (156 + 107 = 263 for the HumanML
        layout), allocated on ``upper_emb``'s device.
    """
    # Assumes the two masks partition the feature axis, so the full width is
    # the sum of the part widths -- TODO confirm for non-HumanML layouts.
    feat_dim = upper_emb.shape[-1] + lower_emb.shape[-1]
    # Allocate directly on the target device rather than on the CPU first.
    motion = torch.empty(*upper_emb.shape[:2], feat_dim, device=upper_emb.device)
    motion[..., HML_UPPER_BODY_MASK] = upper_emb
    motion[..., HML_LOWER_BODY_MASK] = lower_emb
    return motion


def upper_lower_sep(motion, joints_num):
    """Split a motion feature tensor into upper- and lower-body features.

    The last dimension is read as: 4 root values, ``(joints_num - 1) * 3``
    positions, ``(joints_num - 1) * 6`` rotation values, ``joints_num * 3``
    joint velocities, and the remainder as foot contacts.

    Args:
        motion: Three-dimensional feature tensor; features on the last dim.
        joints_num: Number of joints; 22 selects one joint grouping, any other
            value selects the alternative (21-joint style) grouping.

    Returns:
        ``(upper_emb, lower_emb)``. The lower part also carries the root
        values and the foot contacts; the upper part carries only per-joint
        positions, rotations and velocities.
    """
    lead_shape = motion.shape[:2]

    # root
    _root = motion[..., :4]

    # position
    start_indx = 1 + 2 + 1
    end_indx = start_indx + (joints_num - 1) * 3
    positions = motion[..., start_indx:end_indx]
    positions = positions.view(*lead_shape, (joints_num - 1), 3)

    # 6drot
    start_indx = end_indx
    end_indx = start_indx + (joints_num - 1) * 6
    _6d_rot = motion[..., start_indx:end_indx]
    _6d_rot = _6d_rot.view(*lead_shape, (joints_num - 1), 6)

    # joint_velo
    start_indx = end_indx
    end_indx = start_indx + joints_num * 3
    joint_velo = motion[..., start_indx:end_indx]
    joint_velo = joint_velo.view(*lead_shape, joints_num, 3)

    # foot_contact
    foot_contact = motion[..., end_indx:]

    # Index tensors live on the same device as the data they index.
    device = motion.device

    #### Lower Body
    if joints_num == 22:
        lower_body = torch.tensor([0, 1, 2, 4, 5, 7, 8, 10, 11], device=device)
    else:
        lower_body = torch.tensor([0, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20], device=device)
    # Positions and rotations omit joint 0, so drop it and shift indices by 1.
    lower_body_exclude_root = lower_body[1:] - 1

    LOW_positions = positions[:, :, lower_body_exclude_root].view(*lead_shape, -1)
    LOW_6d_rot = _6d_rot[:, :, lower_body_exclude_root].view(*lead_shape, -1)
    LOW_joint_velo = joint_velo[:, :, lower_body].view(*lead_shape, -1)
    lower_emb = torch.cat(
        [_root, LOW_positions, LOW_6d_rot, LOW_joint_velo, foot_contact], dim=-1)

    #### Upper Body
    if joints_num == 22:
        upper_body = torch.tensor([3, 6, 9, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21], device=device)
    else:
        upper_body = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], device=device)
    upper_body_exclude_root = upper_body - 1

    UP_positions = positions[:, :, upper_body_exclude_root].view(*lead_shape, -1)
    UP_6d_rot = _6d_rot[:, :, upper_body_exclude_root].view(*lead_shape, -1)
    UP_joint_velo = joint_velo[:, :, upper_body].view(*lead_shape, -1)
    upper_emb = torch.cat([UP_positions, UP_6d_rot, UP_joint_velo], dim=-1)

    return upper_emb, lower_emb