# This code is based on https://github.com/openai/guided-diffusion
"""
This code started out as a PyTorch port of Ho et al's diffusion models:
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py

Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules.
"""

import enum
import math

import numpy as np
import torch
import torch as th

from copy import deepcopy
from torch import optim, nn

from diffusion.nn import mean_flat, sum_flat
from diffusion.losses import normal_kl, discretized_gaussian_log_likelihood
from data_loaders.humanml.scripts import motion_process

import utils.model_util as model_util  # obj_verts, obj_faces
import utils.model_utils as model_utils
import utils.common_utils as common_utils
from manopth.manolayer import ManoLayer
from sample.reconstruct_data import calculate_disp_quants_batched, calculate_disp_quants_batched_v2


def get_named_beta_schedule(schedule_name, num_diffusion_timesteps, scale_betas=1.):
    """
    Get a pre-defined beta schedule for the given name.

    The beta schedule library consists of beta schedules which remain similar
    in the limit of num_diffusion_timesteps.
    Beta schedules may be added, but should not be removed or changed once
    they are committed to maintain backwards compatibility.
    """
    if schedule_name == "linear":
        # Linear schedule from Ho et al, extended to work for any number of
        # diffusion steps.
        scale = scale_betas * 1000 / num_diffusion_timesteps
        beta_start = scale * 0.0001
        beta_end = scale * 0.02
        return np.linspace(
            beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
        )
    elif schedule_name == "cosine":
        return betas_for_alpha_bar(
            num_diffusion_timesteps,
            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
        )
    else:
        raise NotImplementedError(f"unknown beta schedule: {schedule_name}")


def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].

    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)
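
# [Illustrative sketch, not part of the original file] A minimal sanity check
# for the two schedules above, assuming only numpy: betas should lie in (0, 1]
# and the implied cumulative product alpha_bar should decrease monotonically
# from ~1 toward 0 as t grows.
def _demo_beta_schedules(num_steps=1000):
    for name in ("linear", "cosine"):
        betas = get_named_beta_schedule(name, num_steps)
        assert betas.shape == (num_steps,)
        assert (betas > 0).all() and (betas <= 1).all()
        alpha_bar = np.cumprod(1.0 - betas)
        # alpha_bar is strictly decreasing, so noise is added at every step
        assert (np.diff(alpha_bar) < 0).all()
        print(f"{name}: beta in [{betas.min():.2e}, {betas.max():.2e}], "
              f"final alpha_bar = {alpha_bar[-1]:.2e}")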
""" LEARNED = enum.auto() FIXED_SMALL = enum.auto() FIXED_LARGE = enum.auto() LEARNED_RANGE = enum.auto() class LossType(enum.Enum): MSE = enum.auto() # use raw MSE loss (and KL when learning variances) RESCALED_MSE = ( enum.auto() ) # use raw MSE loss (with RESCALED_KL when learning variances) KL = enum.auto() # use the variational lower-bound RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB def is_vb(self): return self == LossType.KL or self == LossType.RESCALED_KL class GaussianDiffusion: # Gaussian """ Utilities for training and sampling diffusion models. Ported directly from here, and then adapted over time to further experimentation. https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 :param betas: a 1-D numpy array of betas for each diffusion timestep, ## pass floating point timesteps ## starting at T and going to 1. :param model_mean_type: a ModelMeanType determining what the model outputs. :param model_var_type: a ModelVarType determining how variance is output. :param loss_type: a LossType determining the loss function to use. :param rescale_timesteps: if True, pass floating point timesteps into the model so that they are always scaled like in the original paper (0 to 1000). """ def __init__( self, *, betas, model_mean_type, model_var_type, loss_type, rescale_timesteps=False, lambda_rcxyz=0., lambda_vel=0., lambda_pose=1., lambda_orient=1., lambda_loc=1., data_rep='rot6d', lambda_root_vel=0., lambda_vel_rcxyz=0., lambda_fc=0., denoising_stra="rep", inter_optim=False, ): self.model_mean_type = model_mean_type self.model_var_type = model_var_type ## model var type ## self.loss_type = loss_type self.rescale_timesteps = rescale_timesteps self.data_rep = data_rep if data_rep != 'rot_vel' and lambda_pose != 1.: raise ValueError('lambda_pose is relevant only when training on velocities!') self.lambda_pose = lambda_pose self.lambda_orient = lambda_orient self.lambda_loc = lambda_loc self.lambda_rcxyz = lambda_rcxyz self.lambda_vel = lambda_vel self.lambda_root_vel = lambda_root_vel self.lambda_vel_rcxyz = lambda_vel_rcxyz self.lambda_fc = lambda_fc ### === denoising_stra for the denoising process === ### self.denoising_stra = denoising_stra self.inter_optim = inter_optim if self.lambda_rcxyz > 0. or self.lambda_vel > 0. or self.lambda_root_vel > 0. or \ self.lambda_vel_rcxyz > 0. or self.lambda_fc > 0.: assert self.loss_type == LossType.MSE, 'Geometric losses are supported by MSE loss type only!' # Use float64 for accuracy. 
        betas = np.array(betas, dtype=np.float64)
        self.betas = betas
        assert len(betas.shape) == 1, "betas must be 1-D"
        assert (betas > 0).all() and (betas <= 1).all()

        self.num_timesteps = int(betas.shape[0])

        alphas = 1.0 - betas
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
        self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
        assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = (
            betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        # log calculation clipped because the posterior variance is 0 at the
        # beginning of the diffusion chain.
        self.posterior_log_variance_clipped = np.log(
            np.append(self.posterior_variance[1], self.posterior_variance[1:])
        )
        self.posterior_mean_coef1 = (
            betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        self.posterior_mean_coef2 = (
            (1.0 - self.alphas_cumprod_prev)
            * np.sqrt(alphas)
            / (1.0 - self.alphas_cumprod)
        )

        # must be None for handling mask later on
        self.l2_loss = lambda a, b: (a - b) ** 2  # th.nn.MSELoss(reduction='none')

        ''' Load statistics '''
        avg_joints_motion_ours_fn = f"/home/xueyi/sim/motion-diffusion-model/avg_joints_motion_ours_nb_{700}_nth_{0.005}.npy"
        std_joints_motion_ours_fn = f"/home/xueyi/sim/motion-diffusion-model/std_joints_motion_ours_nb_{700}_nth_{0.005}.npy"
        avg_joints_motion_dists_ours_fn = f"/home/xueyi/sim/motion-diffusion-model/avg_joints_dist_motion_ours_nb_{700}_nth_{0.005}.npy"
        std_joints_motion_dists_ours_fn = f"/home/xueyi/sim/motion-diffusion-model/std_joints_dist_motion_ours_nb_{700}_nth_{0.005}.npy"
        avg_joints_rel = np.load(avg_joints_motion_ours_fn, allow_pickle=True)
        std_joints_rel = np.load(std_joints_motion_ours_fn, allow_pickle=True)
        avg_joints_dists = np.load(avg_joints_motion_dists_ours_fn, allow_pickle=True)
        std_joints_dists = np.load(std_joints_motion_dists_ours_fn, allow_pickle=True)
        # self.avg_joints_rel, self.std_joints_rel; self.avg_joints_dists, self.std_joints_dists
        self.avg_joints_rel = torch.from_numpy(avg_joints_rel).float()
        self.std_joints_rel = torch.from_numpy(std_joints_rel).float()
        self.avg_joints_dists = torch.from_numpy(avg_joints_dists).float()
        self.std_joints_dists = torch.from_numpy(std_joints_dists).float()
        ''' Load statistics '''

        ''' Load avg, std statistics '''
        # self.maxx_rel, minn_rel, maxx_dists, minn_dists #
        rel_dists_stats_fn = "/home/xueyi/sim/motion-diffusion-model/base_pts_rel_dists_stats.npy"
        rel_dists_stats = np.load(rel_dists_stats_fn, allow_pickle=True).item()
        maxx_rel = rel_dists_stats['maxx_rel']
        minn_rel = rel_dists_stats['minn_rel']
        maxx_dists = rel_dists_stats['maxx_dists']
        minn_dists = rel_dists_stats['minn_dists']
        self.maxx_rel = torch.from_numpy(maxx_rel).float()
        self.minn_rel = torch.from_numpy(minn_rel).float()
        self.maxx_dists = torch.from_numpy(maxx_dists).float()
        self.minn_dists = torch.from_numpy(minn_dists).float()
        ''' Load avg, std statistics '''
"/home/xueyi/sim/motion-diffusion-model/avg_joints_motion_ours.npy" std_jts_fn = "/home/xueyi/sim/motion-diffusion-model/std_joints_motion_ours.npy" avg_jts = np.load(avg_jts_fn, allow_pickle=True) std_jts = np.load(std_jts_fn, allow_pickle=True) # self.avg_jts, self.std_jts # self.avg_jts = torch.from_numpy(avg_jts).float() self.std_jts = torch.from_numpy(std_jts).float() ''' Load avg-jts, std-jts ''' def masked_l2(self, a, b, mask): # assuming a.shape == b.shape == bs, J, Jdim, seqlen # assuming mask.shape == bs, 1, 1, seqlen loss = self.l2_loss(a, b) loss = sum_flat(loss * mask.float()) # gives \sigma_euclidean over unmasked elements n_entries = a.shape[1] * a.shape[2] non_zero_elements = sum_flat(mask) * n_entries ## non-zero-elements # print('mask', mask.shape) # print('non_zero_elements', non_zero_elements) # print('loss', loss) mse_loss_val = loss / non_zero_elements # print('mse_loss_val', mse_loss_val) return mse_loss_val def q_mean_variance(self, x_start, t): """ Get the distribution q(x_t | x_0). # q-mean-variance # :param x_start: the [N x C x ...] tensor of noiseless inputs. :param t: the number of diffusion steps (minus 1). Here, 0 means one step. :return: A tuple (mean, variance, log_variance), all of x_start's shape. """ mean = ( _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start ) variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) log_variance = _extract_into_tensor( self.log_one_minus_alphas_cumprod, t, x_start.shape ) return mean, variance, log_variance def q_sample(self, x_start, t, noise=None): """ Diffuse the dataset for a given number of diffusion steps. In other words, sample from q(x_t | x_0). ## q pos :param x_start: the initial dataset batch. :param t: the number of diffusion steps (minus 1). Here, 0 means one step. :param noise: if specified, the split-out normal noise. :return: A noisy version of x_start. """ if noise is None: noise = th.randn_like(x_start) assert noise.shape == x_start.shape return ( _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise ) ## q_sample, def q_posterior_mean_variance(self, x_start, x_t, t): ## q_posterior """ Compute the mean and variance of the diffusion posterior: q(x_{t-1} | x_t, x_0) """ assert x_start.shape == x_t.shape posterior_mean = ( # posterior mean and variance # _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t ) posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) posterior_log_variance_clipped = _extract_into_tensor( self.posterior_log_variance_clipped, t, x_t.shape ) assert ( posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == x_start.shape[0] ) return posterior_mean, posterior_variance, posterior_log_variance_clipped def p_mean_variance( ## get mean data ## self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None ): """ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of the initial x, x_0. :param model: the model, which takes a signal and a batch of timesteps as input. :param x: the [N x C x ...] tensor at time t. :param t: a 1-D Tensor of timesteps. :param clip_denoised: if True, clip the denoised signal into [-1, 1]. 
    def p_mean_variance(
        self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
    ):
        """
        Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
        the initial x, x_0.

        :param model: the model, which takes a signal and a batch of timesteps
                      as input.
        :param x: the [N x C x ...] tensor at time t.
        :param t: a 1-D Tensor of timesteps.
        :param clip_denoised: if True, clip the denoised signal into [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample. Applies before
            clip_denoised.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :return: a dict with the following keys:
                 - 'mean': the model mean output.
                 - 'variance': the model variance output.
                 - 'log_variance': the log of 'variance'.
                 - 'pred_xstart': the prediction for x_0.
        """
        if model_kwargs is None:
            model_kwargs = {}

        ## === version 1 -> predict x_start in the rel domain === ##
        ### == x -> formulated as model_inputs == ###
        ### == TODO: process x_start from the predicted value (relative positions, signed distances) with base_pts and base_normals == ###
        B = x['rel_base_pts_to_rhand_joints'].shape[0]
        assert t.shape == (B,)
        input_rel_base_pts_to_rhand_joints = x['rel_base_pts_to_rhand_joints']
        model_output = model(x, self._scale_timesteps(t))

        # if 'inpainting_mask' in model_kwargs['y'].keys() and 'inpainted_motion' in model_kwargs['y'].keys():
        #     inpainting_mask, inpainted_motion = model_kwargs['y']['inpainting_mask'], model_kwargs['y']['inpainted_motion']
        #     assert self.model_mean_type == ModelMeanType.START_X, 'This feature supports only X_start pred for now!'
        #     assert model_output.shape == inpainting_mask.shape == inpainted_motion.shape
        #     model_output = (model_output * ~inpainting_mask) + (inpainted_motion * inpainting_mask)
        #     # print('model_output', model_output.shape, model_output)
        #     # print('inpainting_mask', inpainting_mask.shape, inpainting_mask[0,0,0,:])
        #     # print('inpainted_motion', inpainted_motion.shape, inpainted_motion)

        # if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
        #     assert model_output.shape == (B, C * 2, *x.shape[2:])
        #     model_output, model_var_values = th.split(model_output, C, dim=1)
        #     if self.model_var_type == ModelVarType.LEARNED:
        #         model_log_variance = model_var_values
        #         model_variance = th.exp(model_log_variance)
        #     else:
        #         min_log = _extract_into_tensor(
        #             self.posterior_log_variance_clipped, t, x.shape
        #         )
        #         max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
        #         # The model_var_values is [-1, 1] for [min_var, max_var].
        #         frac = (model_var_values + 1) / 2
        #         model_log_variance = frac * max_log + (1 - frac) * min_log
        #         model_variance = th.exp(model_log_variance)
        # else:
        #     ### === model variance and log_variance === ###
        #     model_variance, model_log_variance = {
        #         # for fixedlarge, we set the initial (log-)variance like so
        #         # to get a better decoder log likelihood.
        #         ModelVarType.FIXED_LARGE: (
        #             np.append(self.posterior_variance[1], self.betas[1:]),
        #             np.log(np.append(self.posterior_variance[1], self.betas[1:])),
        #         ),
        #         ModelVarType.FIXED_SMALL: (
        #             self.posterior_variance,
        #             self.posterior_log_variance_clipped,  # model log variance #
        #         ),
        #     }[self.model_var_type]
        #     ### === model variance and log_variance === ###

        model_variance, model_log_variance = self.posterior_variance, self.posterior_log_variance_clipped
        # print('model_variance', model_variance)
        # print('model_log_variance', model_log_variance)
        # print('self.posterior_variance', self.posterior_variance)
        # print('self.posterior_log_variance_clipped', self.posterior_log_variance_clipped)
        # print('self.model_var_type', self.model_var_type)

        pred_rel_base_pts_to_rhand_joints = model_output['dec_rel']  # predicted relative positions
        model_variance = _extract_into_tensor(model_variance, t, input_rel_base_pts_to_rhand_joints.shape)
        model_log_variance = _extract_into_tensor(model_log_variance, t, input_rel_base_pts_to_rhand_joints.shape)

        # def process_xstart(x):
        #     if denoised_fn is not None:
        #         x = denoised_fn(x)
        #     if clip_denoised:
        #         # print('clip_denoised', clip_denoised)
        #         return x.clamp(-1, 1)
        #     return x

        # if self.model_mean_type == ModelMeanType.PREVIOUS_X:
        #     pred_xstart = process_xstart(
        #         self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
        #     )
        #     model_mean = model_output
        # elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:  # THIS IS US!
        #     if self.model_mean_type == ModelMeanType.START_X:
        #         pred_xstart = process_xstart(model_output)
        #     else:
        #         pred_xstart = process_xstart(
        #             self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
        #         )
        model_mean, _, _ = self.q_posterior_mean_variance(
            x_start=pred_rel_base_pts_to_rhand_joints, x_t=input_rel_base_pts_to_rhand_joints, t=t
        )
        # else:
        #     raise NotImplementedError(self.model_mean_type)

        assert (
            model_mean.shape
            == model_log_variance.shape
            == pred_rel_base_pts_to_rhand_joints.shape
            == input_rel_base_pts_to_rhand_joints.shape
        )
        return {
            "mean": model_mean,
            "variance": model_variance,
            "log_variance": model_log_variance,
            "pred_xstart": pred_rel_base_pts_to_rhand_joints,
        }

    def _predict_xstart_from_eps(self, x_t, t, eps):
        assert x_t.shape == eps.shape
        return (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
        )

    def _predict_xstart_from_xprev(self, x_t, t, xprev):
        assert x_t.shape == xprev.shape
        return (  # (xprev - coef2 * x_t) / coef1
            _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
            - _extract_into_tensor(
                self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
            )
            * x_t
        )

    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
        return (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - pred_xstart
        ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    def _scale_timesteps(self, t):
        if self.rescale_timesteps:
            return t.float() * (1000.0 / self.num_timesteps)
        return t

    def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute the mean for the previous step, given a function cond_fn that
        computes the gradient of a conditional log probability with respect to
        x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
        condition on y.

        This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
        """
""" gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs) new_mean = ( p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() ) return new_mean def condition_mean_with_grad(self, cond_fn, p_mean_var, x, t, model_kwargs=None): """ Compute the mean for the previous step, given a function cond_fn that computes the gradient of a conditional log probability with respect to x. In particular, cond_fn computes grad(log(p(y|x))), and we want to condition on y. This uses the conditioning strategy from Sohl-Dickstein et al. (2015). """ gradient = cond_fn(x, t, p_mean_var, **model_kwargs) new_mean = ( p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() ) return new_mean def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): """ Compute what the p_mean_variance output would have been, should the model's score function be conditioned by cond_fn. See condition_mean() for details on cond_fn. Unlike condition_mean(), this instead uses the conditioning strategy from Song et al (2020). """ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) eps = eps - (1 - alpha_bar).sqrt() * cond_fn( x, self._scale_timesteps(t), **model_kwargs ) out = p_mean_var.copy() out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) out["mean"], _, _ = self.q_posterior_mean_variance( x_start=out["pred_xstart"], x_t=x, t=t ) return out def condition_score_with_grad(self, cond_fn, p_mean_var, x, t, model_kwargs=None): """ Compute what the p_mean_variance output would have been, should the model's score function be conditioned by cond_fn. See condition_mean() for details on cond_fn. Unlike condition_mean(), this instead uses the conditioning strategy from Song et al (2020). """ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) eps = eps - (1 - alpha_bar).sqrt() * cond_fn( x, t, p_mean_var, **model_kwargs ) out = p_mean_var.copy() out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) out["mean"], _, _ = self.q_posterior_mean_variance( x_start=out["pred_xstart"], x_t=x, t=t ) return out def p_sample( self, model, x, t, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, const_noise=False, ): """ Sample x_{t-1} from the model at the given timestep. # mean or noise :param model: the model to sample from. :param x: the current tensor at x_{t-1}. :param t: the value of t, starting at 0 for the first diffusion step. :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. :param denoised_fn: if not None, a function which applies to the x_start prediction before it is used to sample. :param cond_fn: if not None, this is a gradient function that acts similarly to the model. :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :return: a dict containing the following keys: - 'sample': a random sample from the model. - 'pred_xstart': a prediction of x_0. 
""" out = self.p_mean_variance( model, x, t, # starting clip_denoised=clip_denoised, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) # bsz x ws x nnj x nnb x 3 # noise = th.randn_like(x['rel_base_pts_to_rhand_joints']) # print('const_noise', const_noise) if const_noise: noise = noise[[0]].repeat(x['rel_base_pts_to_rhand_joints'].shape[0], 1, 1, 1, 1) nonzero_mask = ( (t != 0).float().view(-1, *([1] * (len(x['rel_base_pts_to_rhand_joints'].shape) - 1))) ) # no noise when t == 0 if cond_fn is not None: out["mean"] = self.condition_mean( cond_fn, out, x['rel_base_pts_to_rhand_joints'], t, model_kwargs=model_kwargs ) # print('mean', out["mean"].shape, out["mean"]) # print('log_variance', out["log_variance"].shape, out["log_variance"]) # print('nonzero_mask', nonzero_mask.shape, nonzero_mask) # sample # why the out only remember relative sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise return {"sample": sample, "pred_xstart": out["pred_xstart"]} def p_sample_with_grad( self, model, x, t, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, ): """ Sample x_{t-1} from the model at the given timestep. :param model: the model to sample from. :param x: the current tensor at x_{t-1}. :param t: the value of t, starting at 0 for the first diffusion step. :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. :param denoised_fn: if not None, a function which applies to the x_start prediction before it is used to sample. :param cond_fn: if not None, this is a gradient function that acts similarly to the model. :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :return: a dict containing the following keys: - 'sample': a random sample from the model. - 'pred_xstart': a prediction of x_0. """ with th.enable_grad(): x = x.detach().requires_grad_() out = self.p_mean_variance( model, x, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) noise = th.randn_like(x) nonzero_mask = ( (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) ) # no noise when t == 0 if cond_fn is not None: out["mean"] = self.condition_mean_with_grad( cond_fn, out, x, t, model_kwargs=model_kwargs ) sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise return {"sample": sample, "pred_xstart": out["pred_xstart"].detach()} def p_sample_loop( self, model, shape, noise=None, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, progress=False, skip_timesteps=0, init_image=None, randomize_class=False, cond_fn_with_grad=False, dump_steps=None, const_noise=False, st_timestep=None, ): ## """ Generate samples from the model. :param model: the model module. :param shape: the shape of the samples, (N, C, H, W). :param noise: if specified, the noise from the encoder to sample. Should be of the same shape as `shape`. :param clip_denoised: if True, clip x_start predictions to [-1, 1]. :param denoised_fn: if not None, a function which applies to the x_start prediction before it is used to sample. :param cond_fn: if not None, this is a gradient function that acts similarly to the model. :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :param device: if specified, the device to create the samples on. If not specified, use a model parameter's device. :param progress: if True, show a tqdm progress bar. 
    def p_sample_loop(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        skip_timesteps=0,
        init_image=None,
        randomize_class=False,
        cond_fn_with_grad=False,
        dump_steps=None,
        const_noise=False,
        st_timestep=None,
    ):
        """
        Generate samples from the model.

        :param model: the model module.
        :param shape: the shape of the samples, (N, C, H, W).
        :param noise: if specified, the noise from the encoder to sample.
                      Should be of the same shape as `shape`.
        :param clip_denoised: if True, clip x_start predictions to [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample.
        :param cond_fn: if not None, this is a gradient function that acts
                        similarly to the model.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param device: if specified, the device to create the samples on.
                       If not specified, use a model parameter's device.
        :param progress: if True, show a tqdm progress bar.
        :param const_noise: if True, will noise all samples with the same
                            noise throughout sampling.
        :return: a non-differentiable batch of samples.
        """
        final = None
        # if dump_steps is not None: ## dump steps is not None ##
        dump = []

        for i, sample in enumerate(self.p_sample_loop_progressive(
            model,
            shape,
            noise=noise,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            cond_fn=cond_fn,
            model_kwargs=model_kwargs,
            device=device,
            progress=progress,
            skip_timesteps=skip_timesteps,
            init_image=init_image,
            randomize_class=randomize_class,
            cond_fn_with_grad=cond_fn_with_grad,
            const_noise=const_noise,  # the same noise #
            st_timestep=st_timestep,
        )):
            if dump_steps is not None and i in dump_steps:
                dump.append(deepcopy(sample))
            final = sample
        if dump_steps is not None:
            return dump
        return final

    def p_sample_loop_progressive(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        skip_timesteps=0,
        init_image=None,
        randomize_class=False,
        cond_fn_with_grad=False,
        const_noise=False,
        st_timestep=None,
    ):
        """
        Generate samples from the model and yield intermediate samples from
        each timestep of diffusion.

        Arguments are the same as p_sample_loop().
        Returns a generator over dicts, where each dict is the return value of
        p_sample().
        """
        ####### ==== a conditional sampling from init_images here!!! ==== #######
        ## === give joints shape here === ##
        ### ==== set the shape for sampling ==== ###
        ### === init_image should not be None === ###
        base_pts = init_image['base_pts']
        base_normals = init_image['base_normals']
        rel_base_pts_to_rhand_joints = init_image['rel_base_pts_to_rhand_joints']
        dist_base_pts_to_rhand_joints = init_image['dist_base_pts_to_rhand_joints']
        rhand_joints = init_image['rhand_joints']

        if 'sampled_base_pts_nearest_obj_pc' in init_image:
            ambient_init_image = {
                'sampled_base_pts_nearest_obj_pc': init_image['sampled_base_pts_nearest_obj_pc'],
                'sampled_base_pts_nearest_obj_vns': init_image['sampled_base_pts_nearest_obj_vns'],
            }

        if device is None:
            device = next(model.parameters()).device
        assert isinstance(shape, (tuple, list))
        if noise is not None:
            img = noise
        else:
            img = th.randn(*shape, device=device)

        # if skip_timesteps and init_image is None:
        #     rhand_joints = th.zeros_like(img)

        indices = list(range(self.num_timesteps - skip_timesteps))[::-1]
        if st_timestep is not None:
            indices = indices[-st_timestep:]
        print(f"st_timestep: {st_timestep}, indices: {indices}")

        ''' No rhand_joints here '''
        # if rhand_joints is not None:
        #     # largest variance for sampling?
        #     # [->t] add noise + [<-t] remove noise #
        #     my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0]
        #     pert_rhand_joints = self.q_sample(rhand_joints, my_t, img)
        #     rel_base_pts_to_pert_rhand_joints = pert_rhand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1)
        #     # bsz x ws x nnjts x nnb #
        #     dist_base_pts_to_pert_rhand_joints = torch.sum(
        #         rel_base_pts_to_pert_rhand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1
        #     )
        #     ''' Relative positions and distances normalization, strategy 2 '''
        #     rel_base_pts_to_pert_rhand_joints = (rel_base_pts_to_pert_rhand_joints - self.avg_joints_rel.unsqueeze(-2)) / self.std_joints_rel.unsqueeze(-2)
        #     dist_base_pts_to_pert_rhand_joints = (dist_base_pts_to_pert_rhand_joints - self.avg_joints_dists.unsqueeze(-1)) / self.std_joints_dists.unsqueeze(-1)
        #     ''' Relative positions and distances normalization, strategy 2 '''
        ''' No rhand_joints here '''

        if self.denoising_stra == "rep":
            ''' Normalization Strategy 4 '''
            my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0]
            rel_noise = th.randn_like(rel_base_pts_to_rhand_joints)
            # bsz x ws x nnjts x 3 ## sample perturbed joints; Gaussian noise added to rel and dists
            ''' stra 1 -> independent noise '''
            rel_base_pts_to_pert_rhand_joints = self.q_sample(rel_base_pts_to_rhand_joints, my_t, noise=rel_noise)
            print(f"==== Adding noise constant for base points... ====")
            ''' stra 2 -> same noise '''
            # rel_noise = rel_noise[:, :, 0, :, :].unsqueeze(2).repeat(1, 1, rel_base_pts_to_rhand_joints.size(2), 1, 1).contiguous()
            # rel_base_pts_to_pert_rhand_joints = self.q_sample(rel_base_pts_to_rhand_joints, my_t, noise=rel_noise)
            ''' stra 2 -> same noise '''
            dist_noise = th.randn_like(dist_base_pts_to_rhand_joints)
            dist_base_pts_to_pert_rhand_joints = self.q_sample(dist_base_pts_to_rhand_joints, my_t, noise=dist_noise)
            ''' Normalization Strategy 4 '''
        elif self.denoising_stra == "motion_to_rep":
            my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0]
            joints_noise = th.randn_like(rhand_joints)
            pert_rhand_joints = self.q_sample(rhand_joints, my_t, noise=joints_noise)
            pert_rhand_joints_denorm = pert_rhand_joints * self.std_jts.unsqueeze(0).to(rhand_joints.device) + self.avg_jts.unsqueeze(0).to(rhand_joints.device)
            # bsz x nf x nnj x nnb x 3
            rel_base_pts_to_pert_rhand_joints = pert_rhand_joints_denorm.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1)
            dist_base_pts_to_pert_rhand_joints = torch.sum(
                rel_base_pts_to_pert_rhand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1
            )
        else:
            raise ValueError(f"Unrecognized denoising_stra: {self.denoising_stra}")

        input_data = {
            'base_pts': base_pts,
            'base_normals': base_normals,
            'rel_base_pts_to_rhand_joints': rel_base_pts_to_pert_rhand_joints,
            'dist_base_pts_to_rhand_joints': dist_base_pts_to_pert_rhand_joints,
        }
        if 'sampled_base_pts_nearest_obj_pc' in init_image:
            input_data.update(ambient_init_image)

        model_kwargs = {
            k: val for k, val in init_image.items() if k not in input_data
        }
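
        # [Editor's note: illustrative summary, not in the original file.]
        # The per-frame normalization used below (strategy 3) is the affine map
        #     x_norm = (x - mu_frame) / sigma_frame,
        # so each denoising step first denormalizes the network output,
        #     x = x_norm * sigma_frame + mu_frame,
        # recomputes joints / relative offsets / signed distances in world
        # space, and then re-normalizes before feeding the next step.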
        if progress:
            # Lazy import so that we don't depend on tqdm.
            from tqdm.auto import tqdm
            indices = tqdm(indices)

        for i in indices:
            t = th.tensor([i] * shape[0], device=device)

            if randomize_class and 'y' in model_kwargs:
                model_kwargs['y'] = th.randint(low=0,
                                               high=model.num_classes,
                                               size=model_kwargs['y'].shape,
                                               device=model_kwargs['y'].device)
            with th.no_grad():
                # use p_sample_with_grad when gradients through the sampler are needed
                sample_fn = self.p_sample_with_grad if cond_fn_with_grad else self.p_sample
                out = sample_fn(
                    model,
                    input_data,
                    t,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn,
                    cond_fn=cond_fn,
                    model_kwargs=model_kwargs,
                    const_noise=const_noise,
                )
                # yield out
                img = out["sample"]  # bsz x ws x nnj x nnb x 3 #

                ''' Relative positions and distances normalization, strategy 2 '''
                # img = img * self.std_joints_rel.unsqueeze(-2) + self.avg_joints_rel.unsqueeze(-2)
                ''' Relative positions and distances normalization, strategy 2 '''

                ''' Relative positions and distances normalization, strategy 4 '''
                # img = img * (self.maxx_rel - self.minn_rel).unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0).to(img.device)
                ''' Relative positions and distances normalization, strategy 4 '''

                if self.denoising_stra == "rep":
                    ''' Relative positions and distances normalization, strategy 3 '''
                    per_frame_avg_joints_rel = init_image['per_frame_avg_joints_rel']
                    per_frame_std_joints_rel = init_image['per_frame_std_joints_rel']
                    # per_frame_avg_joints_dists_rel = init_image['per_frame_avg_joints_dists_rel']
                    # per_frame_std_joints_dists_rel = init_image['per_frame_std_joints_dists_rel']
                    img = img * per_frame_std_joints_rel + per_frame_avg_joints_rel
                    ''' Relative positions and distances normalization, strategy 3 '''

                ''' sampled base-pts-based joints '''
                # img + base_pts.unsqueeze(1).unsqueeze(1) #
                # --> decrease the related potential when gathering information for sampled_rhand_joints --> #
                #### === todo -> from sampled rel to rhand_joints === ####
                sampled_base_pts_based_joints = img + base_pts.unsqueeze(1).unsqueeze(1)  # bsz x ws x nnj x nnb x 3
                sampled_rhand_joints = sampled_base_pts_based_joints.mean(dim=-2)  # a simple averaging strategy
                # sampled_rhand_joints = sampled_base_pts_based_joints[..., 0, :]  # a simple selection strategy

                if self.inter_optim:
                    # sampled_rhand_joints = model_util.optimize_sampled_hand_joints(sampled_rhand_joints, img, None, base_pts, base_normals)
                    obj_verts, obj_normals, obj_faces = init_image["obj_verts"], init_image["obj_normals"], init_image["obj_faces"]
                    # optimize_sampled_hand_joints_wobj(sampled_joints, rel_base_pts_to_joints, dists_base_pts_to_joints, base_pts, base_normals, obj_verts, obj_normals, obj_faces)
                    # sampled_rhand_joints = model_util.optimize_sampled_hand_joints_wobj(sampled_rhand_joints, img, None, base_pts, base_normals, obj_verts, obj_normals, obj_faces)
                    ### === sampled rhand joints === ###
                    sampled_rhand_joints = model_util.optimize_sampled_hand_joints_wobj_v2(sampled_rhand_joints, img, None, base_pts, base_normals, obj_verts, obj_normals, obj_faces)

                rel_base_pts_to_pert_rhand_joints = sampled_rhand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1)
                # bsz x ws x nnjts x nnb
                dist_base_pts_to_pert_rhand_joints = torch.sum(
                    rel_base_pts_to_pert_rhand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1
                )

                if self.denoising_stra == "rep":
                    ''' Relative positions and distances normalization, strategy 3 '''
                    per_frame_avg_joints_dists_rel = init_image['per_frame_avg_joints_dists_rel']
                    per_frame_std_joints_dists_rel = init_image['per_frame_std_joints_dists_rel']
                    # rel base_pts to joints #
                    rel_base_pts_to_pert_rhand_joints = (rel_base_pts_to_pert_rhand_joints - per_frame_avg_joints_rel) / per_frame_std_joints_rel
                    dist_base_pts_to_pert_rhand_joints = (dist_base_pts_to_pert_rhand_joints - per_frame_avg_joints_dists_rel) / per_frame_std_joints_dists_rel
                    ''' Relative positions and distances normalization, strategy 3 '''

                ## denoise in a regular representation space ##
                # if self.denoising_stra == "motion_to_rep":
                #     sampled_rhand_joints_normed = (sampled_rhand_joints - self.avg_jts.unsqueeze(0).to(rhand_joints.device)) / self.std_jts.unsqueeze(0).to(rhand_joints.device)

                ''' Relative positions and distances normalization, strategy 4 '''
                # rel_base_pts_to_pert_rhand_joints = rel_base_pts_to_pert_rhand_joints / (self.maxx_rel - self.minn_rel).unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0).to(img.device)
                # dist_base_pts_to_pert_rhand_joints = dist_base_pts_to_pert_rhand_joints / (self.maxx_dists - self.minn_dists).unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0).squeeze(-1).to(img.device)
                ''' Relative positions and distances normalization, strategy 4 '''

                input_data = {
                    'sampled_rhand_joints': sampled_rhand_joints,
                    # 'sampled_rhand_joints': pert_rhand_joints,
                    # 'sampled_rhand_joints': base_pts.unsqueeze(1).repeat(1, rhand_joints.size(1), 1, 1),
                    'base_pts': base_pts,
                    'base_normals': base_normals,
                    'rel_base_pts_to_rhand_joints': rel_base_pts_to_pert_rhand_joints,
                    # 'rel_base_pts_to_rhand_joints': img,
                    'dist_base_pts_to_rhand_joints': dist_base_pts_to_pert_rhand_joints,
                }
                if 'sampled_base_pts_nearest_obj_pc' in init_image:
                    input_data.update(ambient_init_image)

                yield input_data

    def ddim_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        eta=0.0,
    ):
        """
        Sample x_{t-1} from the model using DDIM.

        Same usage as p_sample().
        """
        out_orig = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        if cond_fn is not None:
            out = self.condition_score(cond_fn, out_orig, x, t, model_kwargs=model_kwargs)
        else:
            out = out_orig

        # Usually our model outputs epsilon, but we re-derive it
        # in case we used x_start or x_prev prediction.
        eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])

        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
        alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
        sigma = (
            eta
            * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
            * th.sqrt(1 - alpha_bar / alpha_bar_prev)
        )
        # Equation 12.
        noise = th.randn_like(x)
        mean_pred = (
            out["pred_xstart"] * th.sqrt(alpha_bar_prev)
            + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
        )
        nonzero_mask = (
            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        )  # no noise when t == 0
        sample = mean_pred + nonzero_mask * sigma * noise
        return {"sample": sample, "pred_xstart": out_orig["pred_xstart"]}
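
    # [Illustrative sketch, not part of the original file] In ddim_sample
    # above, eta interpolates between deterministic DDIM (eta = 0 => sigma = 0)
    # and ancestral DDPM sampling: at eta = 1 the DDIM sigma equals the
    # posterior std of q(x_{t-1} | x_t, x_0). A numpy-only check on a toy
    # schedule:
    @staticmethod
    def _demo_ddim_sigma(num_steps=100):
        betas = np.linspace(1e-4, 2e-2, num_steps)
        alpha_bar = np.cumprod(1.0 - betas)
        alpha_bar_prev = np.append(1.0, alpha_bar[:-1])
        sigma_eta1 = np.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * np.sqrt(1 - alpha_bar / alpha_bar_prev)
        posterior_std = np.sqrt(betas * (1 - alpha_bar_prev) / (1 - alpha_bar))
        assert np.allclose(sigma_eta1, posterior_std)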
""" with th.enable_grad(): x = x.detach().requires_grad_() out_orig = self.p_mean_variance( model, x, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) if cond_fn is not None: out = self.condition_score_with_grad(cond_fn, out_orig, x, t, model_kwargs=model_kwargs) else: out = out_orig out["pred_xstart"] = out["pred_xstart"].detach() # Usually our model outputs epsilon, but we re-derive it # in case we used x_start or x_prev prediction. eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) sigma = ( eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * th.sqrt(1 - alpha_bar / alpha_bar_prev) ) # Equation 12. noise = th.randn_like(x) mean_pred = ( out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps ) nonzero_mask = ( (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) ) # no noise when t == 0 sample = mean_pred + nonzero_mask * sigma * noise return {"sample": sample, "pred_xstart": out_orig["pred_xstart"].detach()} ## def ddim_reverse_sample( self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None, eta=0.0, ): """ Sample x_{t+1} from the model using DDIM reverse ODE. """ assert eta == 0.0, "Reverse ODE only for deterministic path" out = self.p_mean_variance( model, x, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) # Usually our model outputs epsilon, but we re-derive it # in case we used x_start or x_prev prediction. eps = ( _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - out["pred_xstart"] ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) # Equation 12. reversed mean_pred = ( out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps ) return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} def ddim_sample_loop( self, model, shape, noise=None, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, progress=False, eta=0.0, skip_timesteps=0, init_image=None, randomize_class=False, cond_fn_with_grad=False, dump_steps=None, const_noise=False, ): """ Generate samples from the model using DDIM. Same usage as p_sample_loop(). """ if dump_steps is not None: raise NotImplementedError() if const_noise == True: raise NotImplementedError() final = None for sample in self.ddim_sample_loop_progressive( model, shape, noise=noise, clip_denoised=clip_denoised, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, device=device, progress=progress, eta=eta, skip_timesteps=skip_timesteps, init_image=init_image, randomize_class=randomize_class, cond_fn_with_grad=cond_fn_with_grad, ): final = sample return final["sample"] def ddim_sample_loop_progressive( self, model, shape, noise=None, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, progress=False, eta=0.0, skip_timesteps=0, init_image=None, randomize_class=False, cond_fn_with_grad=False, ): """ Use DDIM to sample from the model and yield intermediate samples from each timestep of DDIM. Same usage as p_sample_loop_progressive(). 
""" if device is None: device = next(model.parameters()).device assert isinstance(shape, (tuple, list)) if noise is not None: img = noise else: img = th.randn(*shape, device=device) if skip_timesteps and init_image is None: init_image = th.zeros_like(img) indices = list(range(self.num_timesteps - skip_timesteps))[::-1] if init_image is not None: my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0] img = self.q_sample(init_image, my_t, img) if progress: # Lazy import so that we don't depend on tqdm. from tqdm.auto import tqdm indices = tqdm(indices) for i in indices: t = th.tensor([i] * shape[0], device=device) if randomize_class and 'y' in model_kwargs: model_kwargs['y'] = th.randint(low=0, high=model.num_classes, size=model_kwargs['y'].shape, device=model_kwargs['y'].device) with th.no_grad(): sample_fn = self.ddim_sample_with_grad if cond_fn_with_grad else self.ddim_sample out = sample_fn( model, img, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, eta=eta, ) yield out img = out["sample"] def plms_sample( self, model, x, t, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, cond_fn_with_grad=False, order=2, old_out=None, ): """ Sample x_{t-1} from the model using Pseudo Linear Multistep. Same usage as p_sample(). """ if not int(order) or not 1 <= order <= 4: raise ValueError('order is invalid (should be int from 1-4).') def get_model_output(x, t): with th.set_grad_enabled(cond_fn_with_grad and cond_fn is not None): x = x.detach().requires_grad_() if cond_fn_with_grad else x out_orig = self.p_mean_variance( model, x, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) if cond_fn is not None: if cond_fn_with_grad: out = self.condition_score_with_grad(cond_fn, out_orig, x, t, model_kwargs=model_kwargs) x = x.detach() else: out = self.condition_score(cond_fn, out_orig, x, t, model_kwargs=model_kwargs) else: out = out_orig # Usually our model outputs epsilon, but we re-derive it # in case we used x_start or x_prev prediction. 
            eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
            return eps, out, out_orig

        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
        alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
        eps, out, out_orig = get_model_output(x, t)

        if order > 1 and old_out is None:
            # Pseudo Improved Euler
            old_eps = [eps]
            mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev) * eps
            eps_2, _, _ = get_model_output(mean_pred, t - 1)
            eps_prime = (eps + eps_2) / 2
            pred_prime = self._predict_xstart_from_eps(x, t, eps_prime)
            mean_pred = pred_prime * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev) * eps_prime
        else:
            # Pseudo Linear Multistep (Adams-Bashforth)
            old_eps = old_out["old_eps"]
            old_eps.append(eps)
            cur_order = min(order, len(old_eps))
            if cur_order == 1:
                eps_prime = old_eps[-1]
            elif cur_order == 2:
                eps_prime = (3 * old_eps[-1] - old_eps[-2]) / 2
            elif cur_order == 3:
                eps_prime = (23 * old_eps[-1] - 16 * old_eps[-2] + 5 * old_eps[-3]) / 12
            elif cur_order == 4:
                eps_prime = (55 * old_eps[-1] - 59 * old_eps[-2] + 37 * old_eps[-3] - 9 * old_eps[-4]) / 24
            else:
                raise RuntimeError('cur_order is invalid.')
            pred_prime = self._predict_xstart_from_eps(x, t, eps_prime)
            mean_pred = pred_prime * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev) * eps_prime

        if len(old_eps) >= order:
            old_eps.pop(0)

        nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        sample = mean_pred * nonzero_mask + out["pred_xstart"] * (1 - nonzero_mask)

        return {"sample": sample, "pred_xstart": out_orig["pred_xstart"], "old_eps": old_eps}

    def plms_sample_loop(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        skip_timesteps=0,
        init_image=None,
        randomize_class=False,
        cond_fn_with_grad=False,
        order=2,
    ):
        """
        Generate samples from the model using Pseudo Linear Multistep.

        Same usage as p_sample_loop().
        """
        final = None
        for sample in self.plms_sample_loop_progressive(
            model,
            shape,
            noise=noise,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            cond_fn=cond_fn,
            model_kwargs=model_kwargs,
            device=device,
            progress=progress,
            skip_timesteps=skip_timesteps,
            init_image=init_image,
            randomize_class=randomize_class,
            cond_fn_with_grad=cond_fn_with_grad,
            order=order,
        ):
            final = sample
        return final["sample"]
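
    # [Illustrative sketch, not part of the original file] The PLMS weights in
    # plms_sample are the standard Adams-Bashforth coefficients; for a constant
    # eps history each rule must reproduce that constant, i.e. the weights must
    # sum to 1. A tiny consistency check:
    @staticmethod
    def _demo_plms_coefficients():
        weights = {
            1: [1.0],
            2: [3 / 2, -1 / 2],
            3: [23 / 12, -16 / 12, 5 / 12],
            4: [55 / 24, -59 / 24, 37 / 24, -9 / 24],
        }
        for order, w in weights.items():
            assert abs(sum(w) - 1.0) < 1e-12, order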
    def plms_sample_loop_progressive(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        skip_timesteps=0,
        init_image=None,
        randomize_class=False,
        cond_fn_with_grad=False,
        order=2,
    ):
        """
        Use PLMS to sample from the model and yield intermediate samples from
        each timestep of PLMS.

        Same usage as p_sample_loop_progressive().
        """
        if device is None:
            device = next(model.parameters()).device
        assert isinstance(shape, (tuple, list))
        if noise is not None:
            img = noise
        else:
            img = th.randn(*shape, device=device)

        if skip_timesteps and init_image is None:
            init_image = th.zeros_like(img)

        indices = list(range(self.num_timesteps - skip_timesteps))[::-1]

        if init_image is not None:
            my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0]
            img = self.q_sample(init_image, my_t, img)

        if progress:
            # Lazy import so that we don't depend on tqdm.
            from tqdm.auto import tqdm
            indices = tqdm(indices)

        old_out = None

        for i in indices:
            t = th.tensor([i] * shape[0], device=device)
            if randomize_class and 'y' in model_kwargs:
                model_kwargs['y'] = th.randint(low=0,
                                               high=model.num_classes,
                                               size=model_kwargs['y'].shape,
                                               device=model_kwargs['y'].device)
            with th.no_grad():
                out = self.plms_sample(
                    model,
                    img,
                    t,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn,
                    cond_fn=cond_fn,
                    model_kwargs=model_kwargs,
                    cond_fn_with_grad=cond_fn_with_grad,
                    order=order,
                    old_out=old_out,
                )
                yield out
                old_out = out
                img = out["sample"]

    def _vb_terms_bpd(
        self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
    ):
        """
        Get a term for the variational lower-bound.

        The resulting units are bits (rather than nats, as one might expect).
        This allows for comparison to other papers.

        :return: a dict with the following keys:
                 - 'output': a shape [N] tensor of NLLs or KLs.
                 - 'pred_xstart': the x_0 predictions.
        """
        true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
            x_start=x_start, x_t=x_t, t=t
        )
        out = self.p_mean_variance(
            model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
        )
        kl = normal_kl(
            true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
        )
        kl = mean_flat(kl) / np.log(2.0)

        decoder_nll = -discretized_gaussian_log_likelihood(
            x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
        )
        assert decoder_nll.shape == x_start.shape
        decoder_nll = mean_flat(decoder_nll) / np.log(2.0)

        # At the first timestep return the decoder NLL,
        # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
        output = th.where((t == 0), decoder_nll, kl)
        return {"output": output, "pred_xstart": out["pred_xstart"]}
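
    # [Illustrative sketch, not part of the original file] _vb_terms_bpd
    # reports bits rather than nats by dividing by log(2). A tiny check using
    # the imported normal_kl(mean1, logvar1, mean2, logvar2): the KL between
    # identical Gaussians is 0, and for N(0,1) vs N(1,1) it is 0.5 nats.
    @staticmethod
    def _demo_nats_to_bits():
        zero = th.zeros(4)
        kl_same = normal_kl(zero, zero, zero, zero)
        assert th.allclose(kl_same, th.zeros_like(kl_same))
        kl_nats = normal_kl(zero, zero, zero + 1.0, zero)  # means differ by 1
        assert th.allclose(kl_nats, 0.5 * th.ones_like(kl_nats))
        kl_bits = kl_nats / np.log(2.0)  # the same conversion _vb_terms_bpd uses
        assert th.allclose(kl_bits, (0.5 / np.log(2.0)) * th.ones_like(kl_bits))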
""" # enc = model.model._modules['module'] enc = model.model mask = model_kwargs['y']['mask'] ## rot2xyz; # get_xyz = lambda sample: enc.rot2xyz(sample, mask=None, pose_rep=enc.pose_rep, translation=enc.translation, glob=enc.glob, ## rot2xyz; ## # jointstype='vertices', # 3.4 iter/sec # USED ALSO IN MotionCLIP jointstype='smpl', # 3.4 iter/sec vertstrans=False) # bsz x ws x nnj x 3 # # base_pts: bsz x nnb x 3 # # base_normals: bsz x nnb x 3 # # bsz x ws x nnjts x 3 # rhand_joints = x_start['rhand_joints'] # bsz x nnbase x 3 # base_pts = x_start['base_pts'] # bsz x ws x nnbase x 3 # base_normals = x_start['base_normals'] # bsz x ws x nnjts x nnbase x 3 # rel_base_pts_to_rhand_joints = x_start['rel_base_pts_to_rhand_joints'] # bsz x ws x nnjts x nnbase # dist_base_pts_to_rhand_joints = x_start['dist_base_pts_to_rhand_joints'] if 'sampled_base_pts_nearest_obj_pc' in x_start: ambient_xstart_dict = { 'sampled_base_pts_nearest_obj_pc': x_start['sampled_base_pts_nearest_obj_pc'], 'sampled_base_pts_nearest_obj_vns': x_start['sampled_base_pts_nearest_obj_vns'], } ''' Permute rhand joints for permuting relative position values and distance values ''' # if noise is None: # rhand_joints_noise = th.randn_like(rhand_joints) # # bsz x ws x nnjts x 3 # ## sample perturbed joints ## # pert_rhand_joints = self.q_sample(rhand_joints, t, noise=rhand_joints_noise) # # bsz x ws x nnjts x nnb x 3 # # # print(f"pert_rhand_joints: {pert_rhand_joints.size()}, base_pts: {base_pts.size()}") # rel_base_pts_to_pert_rhand_joints = pert_rhand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1) # # bsz x ws x nnjts x nnb # dist_base_pts_to_pert_rhand_joints = torch.sum( # rel_base_pts_to_pert_rhand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1 # ) # ''' Relative positions and distances normalization, strategy 2 ''' # # bsz x ws x nf x nb x 3 # # 1 x nf x 3 # rel_base_pts_to_pert_rhand_joints = (rel_base_pts_to_pert_rhand_joints - self.avg_joints_rel.unsqueeze(-2).unsqueeze(1).to(pert_rhand_joints.device)) / self.std_joints_rel.unsqueeze(-2).unsqueeze(1).to(pert_rhand_joints.device) # dist_base_pts_to_pert_rhand_joints = (dist_base_pts_to_pert_rhand_joints - self.avg_joints_dists.unsqueeze(1).unsqueeze(-1).to(pert_rhand_joints.device)) / self.std_joints_dists.unsqueeze(1).unsqueeze(-1).to(pert_rhand_joints.device) # ''' Relative positions and distances normalization, strategy 2 ''' ''' Permute rhand joints for permuting relative position values and distance values ''' ''' GET rel and dists ''' # denoising if self.denoising_stra == "rep": # bsz x ws x nnj x nnb x 3 # if noise is None: rel_noise = th.randn_like(rel_base_pts_to_rhand_joints) # per_frame_avg_joints_rel = x_start['per_frame_avg_joints_rel'] # per_frame_std_joints_rel = x_start['per_frame_std_joints_rel'] # std joints rel # # denorm_rel_base_pts_to_rhand_joints = (rel_base_pts_to_rhand_joints * per_frame_std_joints_rel) + per_frame_avg_joints_rel # denorm_pert_rel_base_pts_to_rhand_joints = # norm_rel_noise = (rel_noise - ) ''' stra 1 -> independent noise ''' # bsz x ws x nnjts x 3 # ## sample perturbed joints ## # rel_base_pts_to_pert_rhand_joints = self.q_sample(rel_base_pts_to_rhand_joints, t, noise=rel_noise) ''' stra 2 -> same noise ''' rel_noise = rel_noise[:, :, 0, :, :].unsqueeze(2).repeat(1, 1, rel_base_pts_to_rhand_joints.size(2), 1, 1).contiguous() rel_base_pts_to_pert_rhand_joints = self.q_sample(rel_base_pts_to_rhand_joints, t, noise=rel_noise) ''' stra 2 -> same noise ''' # normalization for each framej -> the relative positions and 
            # signed distances #
            dist_noise = th.randn_like(dist_base_pts_to_rhand_joints)
            dist_base_pts_to_pert_rhand_joints = self.q_sample(dist_base_pts_to_rhand_joints, t, noise=dist_noise)
        elif self.denoising_stra == "motion_to_rep":
            # print(f"Using denoising stra: {self.denoising_stra}")
            joints_noise = torch.randn_like(rhand_joints)
            pert_rhand_joints = self.q_sample(rhand_joints, t, noise=joints_noise)  # q_sample for the noisy joints
            # pert_rhand_joints: bsz x nf x nnj x 3 --> perturbed joints
            # base_pts: bsz x nnb x 3; avg jts and std jts #
            pert_rhand_joints_denorm = pert_rhand_joints * self.std_jts.unsqueeze(0).to(rhand_joints.device) + self.avg_jts.unsqueeze(0).to(rhand_joints.device)
            # bsz x nf x nnj x nnb x 3
            rel_base_pts_to_pert_rhand_joints = pert_rhand_joints_denorm.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1)
            dist_base_pts_to_pert_rhand_joints = torch.sum(
                rel_base_pts_to_pert_rhand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1
            )
        else:
            raise ValueError(f"Unrecognized denoising_stra: {self.denoising_stra}")
        ''' GET rel and dists '''

        input_data = {
            'base_pts': base_pts.clone(),
            'base_normals': base_normals.clone(),
            'rel_base_pts_to_rhand_joints': rel_base_pts_to_pert_rhand_joints.clone(),
            'dist_base_pts_to_rhand_joints': dist_base_pts_to_pert_rhand_joints.clone(),
        }
        if 'sampled_base_pts_nearest_obj_pc' in x_start:
            input_data.update(ambient_xstart_dict)
        # input_data.update(
        #     {k: x_start[k].clone() for k in x_start if k not in input_data}
        # )

        if model_kwargs is None:
            model_kwargs = {}

        # if noise is None:
        #     noise = th.randn_like(x_start)
        # x_t = self.q_sample(x_start, t, noise=noise)

        terms = {}
        model_output = model(input_data, self._scale_timesteps(t).clone())  # model_output ---> model(x_t) #

        # if self.model_var_type in [
        #     ModelVarType.LEARNED,
        #     ModelVarType.LEARNED_RANGE,
        # ]:
        #     B, C = x_t.shape[:2]
        #     assert model_output.shape == (B, C * 2, *x_t.shape[2:])
        #     model_output, model_var_values = th.split(model_output, C, dim=1)
        #     # Learn the variance using the variational bound, but don't let
        #     # it affect our mean prediction.
        #     frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
        #     terms["vb"] = self._vb_terms_bpd(
        #         model=lambda *args, r=frozen_out: r,
        #         x_start=x_start,
        #         x_t=x_t,
        #         t=t,
        #         clip_denoised=False,
        #     )["output"]
        #     if self.loss_type == LossType.RESCALED_MSE:
        #         # Divide by 1000 for equivalence with initial implementation.
        #         # Without a factor of 1/1000, the VB term hurts the MSE term.
# terms["vb"] *= self.num_timesteps / 1000.0 # target = { # # q posterior mean variance # # # ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( # # x_start=x_start, x_t=x_t, t=t # # )[0], # ModelMeanType.START_X: x_start, # # ModelMeanType.EPSILON: noise, # }[self.model_mean_type] # model mean type --> mean type # # assert model_output.shape == target.shape == x_start.shape # [bs, njoints, nfeats, nframes] target = x_start target_rel_base_pts_to_jts = target['rel_base_pts_to_rhand_joints'] target_dist_base_pts_to_jts = target['dist_base_pts_to_rhand_joints'] dec_rel = model_output['dec_rel'] dec_dist = model_output['dec_dist'] # print(f'target_rel_base_pts_to_jts: {target_rel_base_pts_to_jts.size()}, target_dist_base_pts_to_jts: {target_dist_base_pts_to_jts.size()}, dec_rel: {dec_rel.size()}, dec_dist: {dec_dist.size()}') # terms['rot_mse'] = torch.sum( # (target_rel_base_pts_to_jts - dec_rel) ** 2, dim=-1 # ).mean() + torch.mean( # (target_dist_base_pts_to_jts - dec_dist) ** 2, dim=-1 # ) terms['rel_pred'] = torch.sum( (target_rel_base_pts_to_jts - dec_rel) ** 2, dim=-1 ).mean() terms['dist_pred'] = ((target_dist_base_pts_to_jts - dec_dist) ** 2).mean() terms['rot_mse'] = (torch.sum( (target_rel_base_pts_to_jts - dec_rel) ** 2, dim=-1 ) + (target_dist_base_pts_to_jts - dec_dist) ** 2).mean(dim=-1).mean(dim=-1).mean(dim=-1) # rel to joints? # terms["rot_mse"] = self.masked_l2(target, model_output, mask) # mean_flat(rot_mse) # sv_out_in = { # 'model_output': model_output.detach().cpu().numpy(), # 'target': target.detach().cpu().numpy(), # 't': t.detach().cpu().numpy(), # } import os import datetime cur_time_stamp = datetime.datetime.now().timestamp() cur_time_stamp = str(cur_time_stamp) sv_dir_rt = "/data1/sim/motion-diffusion-model" # sv_out_fn = os.path.join(sv_dir_rt, f"out_{cur_time_stamp}.npy") # np.save(sv_out_fn,sv_out_in ) # print(f"Samples saved to {sv_out_fn}") target_xyz, model_output_xyz = None, None if self.lambda_rcxyz > 0. 
        if self.lambda_rcxyz > 0. and dataset.dataname not in ['motion_ours']:
            print(f"Calculating lambda_rcxyz!!!")
            target_xyz = get_xyz(target)  # [bs, nvertices(vertices)/njoints(smpl), 3, nframes]
            model_output_xyz = get_xyz(model_output)  # [bs, nvertices, 3, nframes]
            terms["rcxyz_mse"] = self.masked_l2(target_xyz, model_output_xyz, mask)  # mean_flat((target_xyz - model_output_xyz) ** 2)

        if self.lambda_vel_rcxyz > 0.:
            if self.data_rep == 'rot6d' and dataset.dataname in ['humanact12', 'uestc']:
                target_xyz = get_xyz(target) if target_xyz is None else target_xyz
                model_output_xyz = get_xyz(model_output) if model_output_xyz is None else model_output_xyz
                target_xyz_vel = (target_xyz[:, :, :, 1:] - target_xyz[:, :, :, :-1])
                model_output_xyz_vel = (model_output_xyz[:, :, :, 1:] - model_output_xyz[:, :, :, :-1])
                terms["vel_xyz_mse"] = self.masked_l2(target_xyz_vel, model_output_xyz_vel, mask[:, :, :, 1:])

        if self.lambda_fc > 0.:  ## lambda_fc ##
            torch.autograd.set_detect_anomaly(True)
            if self.data_rep == 'rot6d' and dataset.dataname in ['humanact12', 'uestc']:
                target_xyz = get_xyz(target) if target_xyz is None else target_xyz
                model_output_xyz = get_xyz(model_output) if model_output_xyz is None else model_output_xyz
                # 'L_Ankle',  # 7, 'R_Ankle',  # 8, 'L_Foot',  # 10, 'R_Foot',  # 11
                l_ankle_idx, r_ankle_idx, l_foot_idx, r_foot_idx = 7, 8, 10, 11
                relevant_joints = [l_ankle_idx, l_foot_idx, r_ankle_idx, r_foot_idx]
                gt_joint_xyz = target_xyz[:, relevant_joints, :, :]  # [BatchSize, 4, 3, Frames]
                gt_joint_vel = torch.linalg.norm(gt_joint_xyz[:, :, :, 1:] - gt_joint_xyz[:, :, :, :-1], axis=2)  # [BatchSize, 4, Frames]
                fc_mask = torch.unsqueeze((gt_joint_vel <= 0.01), dim=2).repeat(1, 1, 3, 1)
                pred_joint_xyz = model_output_xyz[:, relevant_joints, :, :]  # [BatchSize, 4, 3, Frames]
                pred_vel = pred_joint_xyz[:, :, :, 1:] - pred_joint_xyz[:, :, :, :-1]
                pred_vel[~fc_mask] = 0
                terms["fc"] = self.masked_l2(pred_vel,
                                             torch.zeros(pred_vel.shape, device=pred_vel.device),
                                             mask[:, :, :, 1:])

        self.lambda_vel = 0.  # velocity loss hard-disabled here
        if self.lambda_vel > 0.:
            target_vel = (target[..., 1:] - target[..., :-1])
            model_output_vel = (model_output[..., 1:] - model_output[..., :-1])
            terms["vel_mse"] = self.masked_l2(target_vel[:, :-1, :, :],  # Remove last joint, is the root location!
                                              model_output_vel[:, :-1, :, :],
                                              mask[:, :, :, 1:])  # mean_flat((target_vel - model_output_vel) ** 2)

        terms["loss"] = terms["rot_mse"] + terms.get('vb', 0.) + \
                        (self.lambda_vel * terms.get('vel_mse', 0.)) + \
                        (self.lambda_rcxyz * terms.get('rcxyz_mse', 0.)) + \
                        (self.lambda_fc * terms.get('fc', 0.))
        # else:
        #     raise NotImplementedError(self.loss_type)

        return terms
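
    # [Illustrative sketch, not part of the original file] masked_l2 averages
    # the squared error over unmasked entries only. A standalone toy check of
    # that semantics, mirroring the [bs, J, Jdim, seqlen] layout it assumes:
    @staticmethod
    def _demo_masked_l2():
        a = th.ones(1, 2, 3, 4)   # bs x J x Jdim x seqlen
        b = th.zeros(1, 2, 3, 4)
        mask = th.zeros(1, 1, 1, 4)
        mask[..., :2] = 1.0  # only the first two frames count
        loss = ((a - b) ** 2 * mask).sum()
        n_entries = a.shape[1] * a.shape[2]  # J * Jdim
        mse = loss / (mask.sum() * n_entries)
        assert th.isclose(mse, th.tensor(1.0))  # per-entry error is exactly 1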
""" # enc = model.model._modules['module'] # enc = model.model mask = model_kwargs['y']['mask'] ## rot2xyz; # # ### avg_joints, std_joints ### # if 'avg_joints' in model_kwargs['y']: avg_joints = model_kwargs['y']['avg_joints'].unsqueeze(-1) std_joints = model_kwargs['y']['std_joints'].unsqueeze(-1).unsqueeze(-1) else: avg_joints = None std_joints = None # ### avg_joints, std_joints ### # # get_xyz = lambda sample: enc.rot2xyz(sample, mask=None, pose_rep=enc.pose_rep, translation=enc.translation, # glob=enc.glob, ## rot2xyz; ## # # jointstype='vertices', # 3.4 iter/sec # USED ALSO IN MotionCLIP # jointstype='smpl', # 3.4 iter/sec # vertstrans=False) if model_kwargs is None: ## predict single steps --> model_kwargs = {} if noise is None: noise = th.randn_like(x_start) x_t = self.q_sample(x_start, t, noise=noise) # randn_like for x_start, t, x_t --> get x_t from x_start # # how we control the tiem stamp t? terms = {} if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: terms["loss"] = self._vb_terms_bpd( ## vb terms bpd # model=model, x_start=x_start, x_t=x_t, t=t, clip_denoised=False, model_kwargs=model_kwargs, )["output"] if self.loss_type == LossType.RESCALED_KL: terms["loss"] *= self.num_timesteps elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: model_output = model(x_t, self._scale_timesteps(t), **model_kwargs) # model_output ---> model x_t # if self.model_var_type in [ ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE, ]: # s B, C = x_t.shape[:2] assert model_output.shape == (B, C * 2, *x_t.shape[2:]) model_output, model_var_values = th.split(model_output, C, dim=1) # Learn the variance using the variational bound, but don't let # it affect our mean prediction. frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) terms["vb"] = self._vb_terms_bpd( model=lambda *args, r=frozen_out: r, x_start=x_start, x_t=x_t, t=t, clip_denoised=False, )["output"] if self.loss_type == LossType.RESCALED_MSE: # Divide by 1000 for equivalence with initial implementation. # Without a factor of 1/1000, the VB term hurts the MSE term. terms["vb"] *= self.num_timesteps / 1000.0 target = { # q posterior mean variance # ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( x_start=x_start, x_t=x_t, t=t )[0], ModelMeanType.START_X: x_start, ModelMeanType.EPSILON: noise, }[self.model_mean_type] # model mean type --> mean type # assert model_output.shape == target.shape == x_start.shape # [bs, njoints, nfeats, nframes] terms["rot_mse"] = self.masked_l2(target, model_output, mask) # mean_flat(rot_mse) # model_output, target, t # ### avg_joints, std_joints ### # if avg_joints is not None: print(f"model_output: {model_output.size()}, target: {target.size()}, std_joints: {std_joints.size()}, avg_joints: {avg_joints.size()}") print(f"Denormalizing joints...") model_output = (model_output * std_joints) + avg_joints target = (target * std_joints) + avg_joints sv_out_in = { # 'model_output': model_output.detach().cpu().numpy(), 'target': target.detach().cpu().numpy(), 't': t.detach().cpu().numpy(), } import os import datetime cur_time_stamp = datetime.datetime.now().timestamp() cur_time_stamp = str(cur_time_stamp) sv_dir_rt = "/data1/sim/motion-diffusion-model" sv_out_fn = os.path.join(sv_dir_rt, f"out_{cur_time_stamp}.npy") np.save(sv_out_fn,sv_out_in ) print(f"Samples saved to {sv_out_fn}") target_xyz, model_output_xyz = None, None self.lambda_rcxyz = 0. 
if self.lambda_rcxyz > 0.: target_xyz = get_xyz(target) # [bs, nvertices(vertices)/njoints(smpl), 3, nframes] model_output_xyz = get_xyz(model_output) # [bs, nvertices, 3, nframes] terms["rcxyz_mse"] = self.masked_l2(target_xyz, model_output_xyz, mask) # mean_flat((target_xyz - model_output_xyz) ** 2) if self.lambda_vel_rcxyz > 0.: if self.data_rep == 'rot6d' and dataset.dataname in ['humanact12', 'uestc']: target_xyz = get_xyz(target) if target_xyz is None else target_xyz model_output_xyz = get_xyz(model_output) if model_output_xyz is None else model_output_xyz target_xyz_vel = (target_xyz[:, :, :, 1:] - target_xyz[:, :, :, :-1]) model_output_xyz_vel = (model_output_xyz[:, :, :, 1:] - model_output_xyz[:, :, :, :-1]) terms["vel_xyz_mse"] = self.masked_l2(target_xyz_vel, model_output_xyz_vel, mask[:, :, :, 1:]) if self.lambda_fc > 0.: ## lambda fc ## torch.autograd.set_detect_anomaly(True) if self.data_rep == 'rot6d' and dataset.dataname in ['humanact12', 'uestc']: target_xyz = get_xyz(target) if target_xyz is None else target_xyz model_output_xyz = get_xyz(model_output) if model_output_xyz is None else model_output_xyz # 'L_Ankle', # 7, 'R_Ankle', # 8 , 'L_Foot', # 10, 'R_Foot', # 11 l_ankle_idx, r_ankle_idx, l_foot_idx, r_foot_idx = 7, 8, 10, 11 relevant_joints = [l_ankle_idx, l_foot_idx, r_ankle_idx, r_foot_idx] gt_joint_xyz = target_xyz[:, relevant_joints, :, :] # [BatchSize, 4, 3, Frames] gt_joint_vel = torch.linalg.norm(gt_joint_xyz[:, :, :, 1:] - gt_joint_xyz[:, :, :, :-1], axis=2) # [BatchSize, 4, Frames] fc_mask = torch.unsqueeze((gt_joint_vel <= 0.01), dim=2).repeat(1, 1, 3, 1) pred_joint_xyz = model_output_xyz[:, relevant_joints, :, :] # [BatchSize, 4, 3, Frames] pred_vel = pred_joint_xyz[:, :, :, 1:] - pred_joint_xyz[:, :, :, :-1] pred_vel[~fc_mask] = 0 terms["fc"] = self.masked_l2(pred_vel, torch.zeros(pred_vel.shape, device=pred_vel.device), mask[:, :, :, 1:]) if self.lambda_vel > 0.: target_vel = (target[..., 1:] - target[..., :-1]) model_output_vel = (model_output[..., 1:] - model_output[..., :-1]) terms["vel_mse"] = self.masked_l2(target_vel[:, :-1, :, :], # Remove last joint, is the root location! model_output_vel[:, :-1, :, :], mask[:, :, :, 1:]) # mean_flat((target_vel - model_output_vel) ** 2) terms["loss"] = terms["rot_mse"] + terms.get('vb', 0.) 
+\ (self.lambda_vel * terms.get('vel_mse', 0.)) +\ (self.lambda_rcxyz * terms.get('rcxyz_mse', 0.)) + \ (self.lambda_fc * terms.get('fc', 0.)) else: raise NotImplementedError(self.loss_type) return terms, model_output, target, t def fc_loss_rot_repr(self, gt_xyz, pred_xyz, mask): def to_np_cpu(x): return x.detach().cpu().numpy() """ pose_xyz: SMPL batch tensor of shape: [BatchSize, 24, 3, Frames] """ # 'L_Ankle', # 7, 'R_Ankle', # 8 , 'L_Foot', # 10, 'R_Foot', # 11 l_ankle_idx, r_ankle_idx = 7, 8 l_foot_idx, r_foot_idx = 10, 11 """ Contact calculated by 'Kfir Method' Commented code)""" # contact_signal = torch.zeros((pose_xyz.shape[0], pose_xyz.shape[3], 2), device=pose_xyz.device) # [BatchSize, Frames, 2] # left_xyz = 0.5 * (pose_xyz[:, l_ankle_idx, :, :] + pose_xyz[:, l_foot_idx, :, :]) # [BatchSize, 3, Frames] # right_xyz = 0.5 * (pose_xyz[:, r_ankle_idx, :, :] + pose_xyz[:, r_foot_idx, :, :]) # left_z, right_z = left_xyz[:, 2, :], right_xyz[:, 2, :] # [BatchSize, Frames] # left_velocity = torch.linalg.norm(left_xyz[:, :, 2:] - left_xyz[:, :, :-2], axis=1) # [BatchSize, Frames] # right_velocity = torch.linalg.norm(left_xyz[:, :, 2:] - left_xyz[:, :, :-2], axis=1) # # left_z_mask = left_z <= torch.mean(torch.sort(left_z)[0][:, :left_z.shape[1] // 5], axis=-1) # left_z_mask = torch.stack([left_z_mask, left_z_mask], dim=-1) # [BatchSize, Frames, 2] # left_z_mask[:, :, 1] = False # Blank right side # contact_signal[left_z_mask] = 0.4 # # right_z_mask = right_z <= torch.mean(torch.sort(right_z)[0][:, :right_z.shape[1] // 5], axis=-1) # right_z_mask = torch.stack([right_z_mask, right_z_mask], dim=-1) # [BatchSize, Frames, 2] # right_z_mask[:, :, 0] = False # Blank left side # contact_signal[right_z_mask] = 0.4 # contact_signal[left_z <= (torch.mean(torch.sort(left_z)[:left_z.shape[0] // 5]) + 20), 0] = 1 # contact_signal[right_z <= (torch.mean(torch.sort(right_z)[:right_z.shape[0] // 5]) + 20), 1] = 1 # plt.plot(to_np_cpu(left_z[0]), label='left_z') # plt.plot(to_np_cpu(left_velocity[0]), label='left_velocity') # plt.plot(to_np_cpu(contact_signal[0, :, 0]), label='left_fc') # plt.grid() # plt.legend() # plt.show() # plt.plot(to_np_cpu(right_z[0]), label='right_z') # plt.plot(to_np_cpu(right_velocity[0]), label='right_velocity') # plt.plot(to_np_cpu(contact_signal[0, :, 1]), label='right_fc') # plt.grid() # plt.legend() # plt.show() gt_joint_xyz = gt_xyz[:, [l_ankle_idx, l_foot_idx, r_ankle_idx, r_foot_idx], :, :] # [BatchSize, 4, 3, Frames] gt_joint_vel = torch.linalg.norm(gt_joint_xyz[:, :, :, 1:] - gt_joint_xyz[:, :, :, :-1], axis=2) # [BatchSize, 4, Frames] fc_mask = (gt_joint_vel <= 0.01) pred_joint_xyz = pred_xyz[:, [l_ankle_idx, l_foot_idx, r_ankle_idx, r_foot_idx], :, :] # [BatchSize, 4, 3, Frames] pred_joint_vel = torch.linalg.norm(pred_joint_xyz[:, :, :, 1:] - pred_joint_xyz[:, :, :, :-1], axis=2) # [BatchSize, 4, Frames] pred_joint_vel[~fc_mask] = 0 # Blank non-contact velocities frames. [BS,4,FRAMES] pred_joint_vel = torch.unsqueeze(pred_joint_vel, dim=2) """DEBUG CODE""" # print(f'mask: {mask.shape}') # print(f'pred_joint_vel: {pred_joint_vel.shape}') # plt.title(f'Joint: {joint_idx}') # plt.plot(to_np_cpu(gt_joint_vel[0]), label='velocity') # plt.plot(to_np_cpu(fc_mask[0]), label='fc') # plt.grid() # plt.legend() # plt.show() return self.masked_l2(pred_joint_vel, torch.zeros(pred_joint_vel.shape, device=pred_joint_vel.device), mask[:, :, :, 1:]) # TODO - NOT USED YET, JUST COMMITING TO NOT DELETE THIS AND KEEP INITIAL IMPLEMENTATION, NOT DONE! 
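    # Sketch of the foot-contact heuristic implemented by fc_loss_rot_repr above
    # (shapes only; the 0.01 speed threshold is the one hard-coded in the code):
    #   gt_joint_vel : [bs, 4, T-1]   per-frame speed of {L/R ankle, L/R foot}
    #   fc_mask      : gt_joint_vel <= 0.01  -> frames treated as "in contact"
    #   loss         : masked_l2 of the predicted joint velocity on those frames
    #                  against zeros, i.e. feet that should be planted must not slide.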
    def foot_contact_loss_humanml3d(self, target, model_output):
        # HumanML3D motion representation layout (channels-first here):
        # root_rot_velocity (B, seq_len, 1)
        # root_linear_velocity (B, seq_len, 2)
        # root_y (B, seq_len, 1)
        # ric_data (B, seq_len, (joint_num - 1)*3), XYZ
        # rot_data (B, seq_len, (joint_num - 1)*6), 6D
        # local_velocity (B, seq_len, joint_num*3), XYZ
        # foot contact (B, seq_len, 4)
        target_fc = target[:, -4:, :, :]
        root_rot_velocity = target[:, :1, :, :]
        root_linear_velocity = target[:, 1:3, :, :]
        root_y = target[:, 3:4, :, :]
        ric_data = target[:, 4:67, :, :]  # 4+(3*21)=67
        rot_data = target[:, 67:193, :, :]  # 67+(6*21)=193
        local_velocity = target[:, 193:259, :, :]  # 193+(3*22)=259
        contact = target[:, 259:, :, :]  # 259+4=263
        contact_mask_gt = contact > 0.5  # contact mask order for indexes are fid_l [7, 10], fid_r [8, 11]
        vel_lf_7 = local_velocity[:, 7 * 3:8 * 3, :, :]
        vel_rf_8 = local_velocity[:, 8 * 3:9 * 3, :, :]
        vel_lf_10 = local_velocity[:, 10 * 3:11 * 3, :, :]
        vel_rf_11 = local_velocity[:, 11 * 3:12 * 3, :, :]

        calc_vel_lf_7 = ric_data[:, 6 * 3:7 * 3, :, 1:] - ric_data[:, 6 * 3:7 * 3, :, :-1]
        calc_vel_rf_8 = ric_data[:, 7 * 3:8 * 3, :, 1:] - ric_data[:, 7 * 3:8 * 3, :, :-1]
        calc_vel_lf_10 = ric_data[:, 9 * 3:10 * 3, :, 1:] - ric_data[:, 9 * 3:10 * 3, :, :-1]
        calc_vel_rf_11 = ric_data[:, 10 * 3:11 * 3, :, 1:] - ric_data[:, 10 * 3:11 * 3, :, :-1]

        # vel_foots = torch.stack([vel_lf_7, vel_lf_10, vel_rf_8, vel_rf_11], dim=1)
        # NOTE: the calculated velocities below are ordered to match the vector
        # velocities and the contact-mask order [7, 10, 8, 11]; the original code
        # zipped them as [7, 8, 10, 11], misaligning calc vs. vector velocities.
        for chosen_vel_foot_calc, chosen_vel_foot, joint_idx, contact_mask_idx in zip(
                [calc_vel_lf_7, calc_vel_lf_10, calc_vel_rf_8, calc_vel_rf_11],
                [vel_lf_7, vel_lf_10, vel_rf_8, vel_rf_11],
                [7, 10, 8, 11],
                [0, 1, 2, 3]):
            tmp_mask_gt = contact_mask_gt[:, contact_mask_idx, :, :].cpu().detach().numpy().reshape(-1).astype(int)
            chosen_vel_norm = np.linalg.norm(chosen_vel_foot.cpu().detach().numpy().reshape((3, -1)), axis=0)
            chosen_vel_calc_norm = np.linalg.norm(chosen_vel_foot_calc.cpu().detach().numpy().reshape((3, -1)), axis=0)

            print(tmp_mask_gt.shape)
            print(chosen_vel_foot.shape)
            print(chosen_vel_calc_norm.shape)

            import matplotlib.pyplot as plt
            plt.plot(tmp_mask_gt, label='FC mask')
            plt.plot(chosen_vel_norm, label='Vel. XYZ norm (from vector)')
            plt.plot(chosen_vel_calc_norm, label='Vel. XYZ norm (calculated diff XYZ)')
            plt.title(f'FC idx {contact_mask_idx}, Joint Index {joint_idx}')
            plt.legend()
            plt.show()
        # print(vel_foots.shape)
        return 0
    # TODO - NOT USED YET, JUST COMMITING TO NOT DELETE THIS AND KEEP INITIAL IMPLEMENTATION, NOT DONE!
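    # For reference, the channel offsets of the 263-dim HumanML3D representation
    # sliced by the two debug losses here (joint_num = 22, so joint_num - 1 = 21):
    #   root_rot_velocity     1             ->   0:1
    #   root_linear_velocity  2             ->   1:3
    #   root_y                1             ->   3:4
    #   ric_data              21 * 3 =  63  ->   4:67
    #   rot_data              21 * 6 = 126  ->  67:193
    #   local_velocity        22 * 3 =  66  -> 193:259
    #   foot_contact          4             -> 259:263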
    def velocity_consistency_loss_humanml3d(self, target, model_output):
        # HumanML3D motion representation layout (channels-first here):
        # root_rot_velocity (B, seq_len, 1)
        # root_linear_velocity (B, seq_len, 2)
        # root_y (B, seq_len, 1)
        # ric_data (B, seq_len, (joint_num - 1)*3), XYZ
        # rot_data (B, seq_len, (joint_num - 1)*6), 6D
        # local_velocity (B, seq_len, joint_num*3), XYZ
        # foot contact (B, seq_len, 4)
        target_fc = target[:, -4:, :, :]
        root_rot_velocity = target[:, :1, :, :]
        root_linear_velocity = target[:, 1:3, :, :]
        root_y = target[:, 3:4, :, :]
        ric_data = target[:, 4:67, :, :]  # 4+(3*21)=67
        rot_data = target[:, 67:193, :, :]  # 67+(6*21)=193
        local_velocity = target[:, 193:259, :, :]  # 193+(3*22)=259
        contact = target[:, 259:, :, :]  # 259+4=263

        calc_vel_from_xyz = ric_data[:, :, :, 1:] - ric_data[:, :, :, :-1]
        velocity_from_vector = local_velocity[:, 3:, :, 1:]  # Slicing out root
        r_rot_quat, r_pos = motion_process.recover_root_rot_pos(target.permute(0, 2, 3, 1).type(th.FloatTensor))
        print(f'r_rot_quat: {r_rot_quat.shape}')
        print(f'calc_vel_from_xyz: {calc_vel_from_xyz.shape}')
        calc_vel_from_xyz = calc_vel_from_xyz.permute(0, 2, 3, 1)
        calc_vel_from_xyz = calc_vel_from_xyz.reshape((1, 1, -1, 21, 3)).type(th.FloatTensor)
        r_rot_quat_adapted = r_rot_quat[..., :-1, None, :].repeat((1, 1, 1, 21, 1)).to(calc_vel_from_xyz.device)
        print(f'calc_vel_from_xyz: {calc_vel_from_xyz.shape} , {calc_vel_from_xyz.device}')
        print(f'r_rot_quat_adapted: {r_rot_quat_adapted.shape}, {r_rot_quat_adapted.device}')

        calc_vel_from_xyz = motion_process.qrot(r_rot_quat_adapted, calc_vel_from_xyz)
        calc_vel_from_xyz = calc_vel_from_xyz.reshape((1, 1, -1, 21 * 3))
        calc_vel_from_xyz = calc_vel_from_xyz.permute(0, 3, 1, 2)
        print(f'calc_vel_from_xyz: {calc_vel_from_xyz.shape} , {calc_vel_from_xyz.device}')

        import matplotlib.pyplot as plt
        for i in range(21):
            plt.plot(np.linalg.norm(calc_vel_from_xyz[:, i*3:(i+1)*3, :, :].cpu().detach().numpy().reshape((3, -1)), axis=0), label='Calc Vel')
            plt.plot(np.linalg.norm(velocity_from_vector[:, i*3:(i+1)*3, :, :].cpu().detach().numpy().reshape((3, -1)), axis=0), label='Vector Vel')
            plt.title(f'Joint idx: {i}')
            plt.legend()
            plt.show()

        print(calc_vel_from_xyz.shape)
        print(velocity_from_vector.shape)
        diff = calc_vel_from_xyz - velocity_from_vector
        print(np.linalg.norm(diff.cpu().detach().numpy().reshape((63, -1)), axis=0))
        return 0

    def _prior_bpd(self, x_start):
        """
        Get the prior KL term for the variational lower-bound, measured in
        bits-per-dim.

        This term can't be optimized, as it only depends on the encoder.

        :param x_start: the [N x C x ...] tensor of inputs.
        :return: a batch of [N] KL values (in bits), one per batch element.
        """
        batch_size = x_start.shape[0]
        t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
        qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
        kl_prior = normal_kl(
            mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
        )
        return mean_flat(kl_prior) / np.log(2.0)

    def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
        """
        Compute the entire variational lower-bound, measured in bits-per-dim,
        as well as other related quantities.

        :param model: the model to evaluate loss on.
        :param x_start: the [N x C x ...] tensor of inputs.
        :param clip_denoised: if True, clip denoised samples.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.

        :return: a dict containing the following keys:
                 - total_bpd: the total variational lower-bound, per batch element.
                 - prior_bpd: the prior term in the lower-bound.
- vb: an [N x T] tensor of terms in the lower-bound. - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. - mse: an [N x T] tensor of epsilon MSEs for each timestep. """ device = x_start.device batch_size = x_start.shape[0] vb = [] xstart_mse = [] mse = [] for t in list(range(self.num_timesteps))[::-1]: t_batch = th.tensor([t] * batch_size, device=device) noise = th.randn_like(x_start) x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) # Calculate VLB term at the current timestep with th.no_grad(): out = self._vb_terms_bpd( model, x_start=x_start, x_t=x_t, t=t_batch, clip_denoised=clip_denoised, model_kwargs=model_kwargs, ) vb.append(out["output"]) xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) mse.append(mean_flat((eps - noise) ** 2)) vb = th.stack(vb, dim=1) xstart_mse = th.stack(xstart_mse, dim=1) mse = th.stack(mse, dim=1) prior_bpd = self._prior_bpd(x_start) total_bpd = vb.sum(dim=1) + prior_bpd return { "total_bpd": total_bpd, "prior_bpd": prior_bpd, "vb": vb, "xstart_mse": xstart_mse, "mse": mse, } class GaussianDiffusionV2: """ Utilities for training and sampling diffusion models. Ported directly from here, and then adapted over time to further experimentation. https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 :param betas: a 1-D numpy array of betas for each diffusion timestep, ## pass floating point timesteps ## starting at T and going to 1. :param model_mean_type: a ModelMeanType determining what the model outputs. :param model_var_type: a ModelVarType determining how variance is output. :param loss_type: a LossType determining the loss function to use. :param rescale_timesteps: if True, pass floating point timesteps into the model so that they are always scaled like in the original paper (0 to 1000). """ def __init__( self, *, betas, model_mean_type, model_var_type, loss_type, rescale_timesteps=False, lambda_rcxyz=0., lambda_vel=0., lambda_pose=1., lambda_orient=1., lambda_loc=1., data_rep='rot6d', lambda_root_vel=0., lambda_vel_rcxyz=0., lambda_fc=0., denoising_stra="rep", inter_optim=False, ): self.model_mean_type = model_mean_type self.model_var_type = model_var_type ## model var type ## self.loss_type = loss_type self.rescale_timesteps = rescale_timesteps self.data_rep = data_rep if data_rep != 'rot_vel' and lambda_pose != 1.: raise ValueError('lambda_pose is relevant only when training on velocities!') self.lambda_pose = lambda_pose self.lambda_orient = lambda_orient self.lambda_loc = lambda_loc self.lambda_rcxyz = lambda_rcxyz self.lambda_vel = lambda_vel self.lambda_root_vel = lambda_root_vel self.lambda_vel_rcxyz = lambda_vel_rcxyz self.lambda_fc = lambda_fc ### === denoising_stra for the denoising process === ### self.denoising_stra = denoising_stra self.inter_optim = inter_optim if self.lambda_rcxyz > 0. or self.lambda_vel > 0. or self.lambda_root_vel > 0. or \ self.lambda_vel_rcxyz > 0. or self.lambda_fc > 0.: assert self.loss_type == LossType.MSE, 'Geometric losses are supported by MSE loss type only!' # Use float64 for accuracy. 
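        # For reference, the standard DDPM algebra implemented by the buffers
        # below (Ho et al., 2020), with a_t := 1 - beta_t and abar_t := prod_{s<=t} a_s:
        #   q(x_t | x_0)          = N(sqrt(abar_t) * x_0, (1 - abar_t) * I)
        #   q(x_{t-1} | x_t, x_0) = N(mu_t, btilde_t * I), where
        #     btilde_t = beta_t * (1 - abar_{t-1}) / (1 - abar_t)       (posterior_variance)
        #     mu_t     = coef1 * x_0 + coef2 * x_t, with
        #     coef1    = beta_t * sqrt(abar_{t-1}) / (1 - abar_t)       (posterior_mean_coef1)
        #     coef2    = (1 - abar_{t-1}) * sqrt(a_t) / (1 - abar_t)    (posterior_mean_coef2)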
betas = np.array(betas, dtype=np.float64) self.betas = betas assert len(betas.shape) == 1, "betas must be 1-D" ## betas assert (betas > 0).all() and (betas <= 1).all() self.num_timesteps = int(betas.shape[0]) alphas = 1.0 - betas self.alphas_cumprod = np.cumprod(alphas, axis=0) self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) ## alphas cumprod self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) # calculations for diffusion q(x_t | x_{t-1}) and others self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) # calculations for posterior q(x_{t-1} | x_t, x_0) self.posterior_variance = ( betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) ) # log calculation clipped because the posterior variance is 0 at the # beginning of the diffusion chain. self.posterior_log_variance_clipped = np.log( np.append(self.posterior_variance[1], self.posterior_variance[1:]) ) self.posterior_mean_coef1 = ( betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) ) self.posterior_mean_coef2 = ( (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) ) self.l2_loss = lambda a, b: (a - b) ** 2 # th.nn.MSELoss(reduction='none') # must be None for handling mask later on. ''' Load statistics ''' avg_joints_motion_ours_fn = f"/home/xueyi/sim/motion-diffusion-model/avg_joints_motion_ours_nb_{700}_nth_{0.005}.npy" std_joints_motion_ours_fn = f"/home/xueyi/sim/motion-diffusion-model/std_joints_motion_ours_nb_{700}_nth_{0.005}.npy" avg_joints_motion_dists_ours_fn = f"/home/xueyi/sim/motion-diffusion-model/avg_joints_dist_motion_ours_nb_{700}_nth_{0.005}.npy" std_joints_motion_dists_ours_fn = f"/home/xueyi/sim/motion-diffusion-model/std_joints_dist_motion_ours_nb_{700}_nth_{0.005}.npy" avg_joints_rel = np.load(avg_joints_motion_ours_fn, allow_pickle=True) std_joints_rel = np.load(std_joints_motion_ours_fn, allow_pickle=True) avg_joints_dists = np.load(avg_joints_motion_dists_ours_fn, allow_pickle=True) std_joints_dists = np.load(std_joints_motion_dists_ours_fn, allow_pickle=True) ## self.avg_joints_rel, self.std_joints_rel ## self.avg_joints_dists, self.std_joints_dists self.avg_joints_rel = torch.from_numpy(avg_joints_rel).float() self.std_joints_rel = torch.from_numpy(std_joints_rel).float() self.avg_joints_dists = torch.from_numpy(avg_joints_dists).float() self.std_joints_dists = torch.from_numpy(std_joints_dists).float() ''' Load statistics ''' ''' Load avg, std statistics ''' # self.maxx_rel, minn_rel, maxx_dists, minn_dists # rel_dists_stats_fn = "/home/xueyi/sim/motion-diffusion-model/base_pts_rel_dists_stats.npy" rel_dists_stats = np.load(rel_dists_stats_fn, allow_pickle=True).item() maxx_rel = rel_dists_stats['maxx_rel'] minn_rel = rel_dists_stats['minn_rel'] maxx_dists = rel_dists_stats['maxx_dists'] minn_dists = rel_dists_stats['minn_dists'] self.maxx_rel = torch.from_numpy(maxx_rel).float() self.minn_rel = torch.from_numpy(minn_rel).float() self.maxx_dists = torch.from_numpy(maxx_dists).float() self.minn_dists = torch.from_numpy(minn_dists).float() ''' Load avg, std statistics ''' ''' Load avg-jts, std-jts ''' avg_jts_fn = 
"/home/xueyi/sim/motion-diffusion-model/avg_joints_motion_ours.npy" std_jts_fn = "/home/xueyi/sim/motion-diffusion-model/std_joints_motion_ours.npy" avg_jts = np.load(avg_jts_fn, allow_pickle=True) std_jts = np.load(std_jts_fn, allow_pickle=True) # self.avg_jts, self.std_jts # self.avg_jts = torch.from_numpy(avg_jts).float() self.std_jts = torch.from_numpy(std_jts).float() ''' Load avg-jts, std-jts ''' def masked_l2(self, a, b, mask): # assuming a.shape == b.shape == bs, J, Jdim, seqlen # assuming mask.shape == bs, 1, 1, seqlen loss = self.l2_loss(a, b) loss = sum_flat(loss * mask.float()) # gives \sigma_euclidean over unmasked elements n_entries = a.shape[1] * a.shape[2] non_zero_elements = sum_flat(mask) * n_entries ## non-zero-elements # print('mask', mask.shape) # print('non_zero_elements', non_zero_elements) # print('loss', loss) mse_loss_val = loss / non_zero_elements # print('mse_loss_val', mse_loss_val) return mse_loss_val def q_mean_variance(self, x_start, t): """ Get the distribution q(x_t | x_0). # q-mean-variance # :param x_start: the [N x C x ...] tensor of noiseless inputs. :param t: the number of diffusion steps (minus 1). Here, 0 means one step. :return: A tuple (mean, variance, log_variance), all of x_start's shape. """ mean = ( _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start ) variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) log_variance = _extract_into_tensor( self.log_one_minus_alphas_cumprod, t, x_start.shape ) return mean, variance, log_variance def q_sample(self, x_start, t, noise=None): """ Diffuse the dataset for a given number of diffusion steps. In other words, sample from q(x_t | x_0). ## q pos :param x_start: the initial dataset batch. :param t: the number of diffusion steps (minus 1). Here, 0 means one step. :param noise: if specified, the split-out normal noise. :return: A noisy version of x_start. """ if noise is None: noise = th.randn_like(x_start) assert noise.shape == x_start.shape return ( _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise ) ## q_sample, def q_posterior_mean_variance(self, x_start, x_t, t): ## q_posterior """ Compute the mean and variance of the diffusion posterior: q(x_{t-1} | x_t, x_0) """ assert x_start.shape == x_t.shape posterior_mean = ( # posterior mean and variance # _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t ) posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) posterior_log_variance_clipped = _extract_into_tensor( self.posterior_log_variance_clipped, t, x_t.shape ) assert ( posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == x_start.shape[0] ) return posterior_mean, posterior_variance, posterior_log_variance_clipped def p_mean_variance( ## get mean data ## self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None ): """ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of the initial x, x_0. :param model: the model, which takes a signal and a batch of timesteps as input. :param x: the [N x C x ...] tensor at time t. :param t: a 1-D Tensor of timesteps. :param clip_denoised: if True, clip the denoised signal into [-1, 1]. 
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample. Applies before
            clip_denoised.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :return: a dict with the following keys:
                 - 'mean': the model mean output.
                 - 'variance': the model variance output.
                 - 'log_variance': the log of 'variance'.
                 - 'pred_xstart': the prediction for x_0.
        """
        if model_kwargs is None:
            model_kwargs = {}

        ## === version 1 -> predict x_start in the rel domain === ##
        ### == x -> formulated as model_inputs == ###
        ### == TODO: process x_start from the predicted value (relative positions, signed distances) with base_pts and base_normals == ###
        B = x['rel_base_pts_to_rhand_joints'].shape[0]
        assert t.shape == (B,)
        ## rel base pts to rhand joints; input base rel joints ##
        # bsz x ws x nnj x nnb x 3 #
        input_rel_base_pts_to_rhand_joints = x['rel_base_pts_to_rhand_joints']
        input_dist_base_pts_to_rhand_joints = x['dist_base_pts_to_rhand_joints']  # regularize the velocity space #
        # Denormalize the relative positions and translate by the base points to
        # recover joints in world coordinates: bsz x ws x nnj x nnb x 3.
        input_rhand_joints_based_on_base_pts = input_rel_base_pts_to_rhand_joints * x['per_frame_std_joints_rel'] + x['per_frame_avg_joints_rel']
        input_rhand_joints_based_on_base_pts = input_rhand_joints_based_on_base_pts + x['base_pts'].unsqueeze(1).unsqueeze(1)

        model_output, e_model_output = model(x, self._scale_timesteps(t))

        ### === model variance and log_variance === ###
        model_variance, model_log_variance = self.posterior_variance, self.posterior_log_variance_clipped
        # print('model_variance', model_variance)
        # print('model_log_variance', model_log_variance)
        # print('self.posterior_variance', self.posterior_variance)
        # print('self.posterior_log_variance_clipped', self.posterior_log_variance_clipped)
        # print('self.model_var_type', self.model_var_type)

        pred_rel_base_pts_to_rhand_joints = model_output['dec_rel']
        pred_rel_base_pts_to_rhand_joints_denormed = pred_rel_base_pts_to_rhand_joints * x['per_frame_std_joints_rel'] + x['per_frame_avg_joints_rel']
        # pred_rhand_joints_based_on_base_pts: bsz x ws x nnj x nnb x 3 #
        pred_rhand_joints_based_on_base_pts = pred_rel_base_pts_to_rhand_joints_denormed + x['base_pts'].unsqueeze(1).unsqueeze(1)
        pred_rhand_joints = torch.mean(pred_rhand_joints_based_on_base_pts, dim=-2)  # bsz x ws x nnj x 3 #
        # pred_rhand_joints_normed: bsz x ws x nnj x 3 #
        pred_rhand_joints_normed = (pred_rhand_joints - self.avg_jts.unsqueeze(0).to(pred_rhand_joints.device)) / self.std_jts.unsqueeze(0).to(pred_rhand_joints.device)
        input_rhand_joints_normed = (x['rhand_joints'] - self.avg_jts.unsqueeze(0).to(pred_rhand_joints.device)) / self.std_jts.unsqueeze(0).to(pred_rhand_joints.device)

        model_variance = _extract_into_tensor(model_variance, t, x['rhand_joints'].shape)
        model_log_variance = _extract_into_tensor(model_log_variance, t, x['rhand_joints'].shape)

        # rhand_joints_normed_mean, _, _ = self.q_posterior_mean_variance(
        #     x_start=input_rhand_joints_normed, x_t=pred_rhand_joints_normed, t=t
        # )
        rhand_joints_normed_mean, _, _ = self.q_posterior_mean_variance(
            x_start=pred_rhand_joints_normed, x_t=input_rhand_joints_normed, t=t
        )  # bsz x ws x nnj x 3 # denormed #
        ## ==== just for debugging ==== ##
        # rhand_joints_mean = rhand_joints_normed_mean * self.std_jts.unsqueeze(0).to(pred_rhand_joints.device) +
self.avg_jts.unsqueeze(0).to(pred_rhand_joints.device) rhand_joints_mean = pred_rhand_joints_normed * self.std_jts.unsqueeze(0).to(pred_rhand_joints.device) + self.avg_jts.unsqueeze(0).to(pred_rhand_joints.device) # model_mena: bsz x ws x nnj x nnb x 3 # model_mean = rhand_joints_mean.unsqueeze(-2) - x['base_pts'].unsqueeze(1).unsqueeze(1) model_mean = (model_mean - x['per_frame_avg_joints_rel']) / x['per_frame_std_joints_rel'] ## ==== just for debugging ==== ## model_mean = rhand_joints_normed_mean # model_mean = pred_rhand_joints_normed # def process_xstart(x): # if denoised_fn is not None: # x = denoised_fn(x) # if clip_denoised: # # print('clip_denoised', clip_denoised) # return x.clamp(-1, 1) # return x # if self.model_mean_type == ModelMeanType.PREVIOUS_X: # pred_xstart = process_xstart( # self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output) # ) # model_mean = model_output # elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]: # THIS IS US! # if self.model_mean_type == ModelMeanType.START_X: # pred_xstart = process_xstart(model_output) # else: # pred_xstart = process_xstart( # self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) # ) # model_mean, _, _ = self.q_posterior_mean_variance( # x_start=pred_rel_base_pts_to_rhand_joints, x_t=input_rel_base_pts_to_rhand_joints, t=t # ) # else: # raise NotImplementedError(self.model_mean_type) # assert ( # model_mean.shape == model_log_variance.shape == pred_rel_base_pts_to_rhand_joints.shape == input_rel_base_pts_to_rhand_joints.shape # ) return { "mean": model_mean, "variance": model_variance, "log_variance": model_log_variance, "pred_xstart": pred_rel_base_pts_to_rhand_joints, } def _predict_xstart_from_eps(self, x_t, t, eps): assert x_t.shape == eps.shape return ( # extract into tensor # _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps ) def _predict_xstart_from_xprev(self, x_t, t, xprev): assert x_t.shape == xprev.shape return ( # (xprev - coef2*x_t) / coef1 _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev - _extract_into_tensor( self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape ) * x_t ) def _predict_eps_from_xstart(self, x_t, t, pred_xstart): return ( _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) def _scale_timesteps(self, t): if self.rescale_timesteps: return t.float() * (1000.0 / self.num_timesteps) return t def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): """ Compute the mean for the previous step, given a function cond_fn that computes the gradient of a conditional log probability with respect to x. In particular, cond_fn computes grad(log(p(y|x))), and we want to condition on y. This uses the conditioning strategy from Sohl-Dickstein et al. (2015). """ gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs) new_mean = ( p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() ) return new_mean def condition_mean_with_grad(self, cond_fn, p_mean_var, x, t, model_kwargs=None): """ Compute the mean for the previous step, given a function cond_fn that computes the gradient of a conditional log probability with respect to x. In particular, cond_fn computes grad(log(p(y|x))), and we want to condition on y. This uses the conditioning strategy from Sohl-Dickstein et al. (2015). 
""" gradient = cond_fn(x, t, p_mean_var, **model_kwargs) new_mean = ( p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() ) return new_mean def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): """ Compute what the p_mean_variance output would have been, should the model's score function be conditioned by cond_fn. See condition_mean() for details on cond_fn. Unlike condition_mean(), this instead uses the conditioning strategy from Song et al (2020). """ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) eps = eps - (1 - alpha_bar).sqrt() * cond_fn( x, self._scale_timesteps(t), **model_kwargs ) out = p_mean_var.copy() out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) out["mean"], _, _ = self.q_posterior_mean_variance( x_start=out["pred_xstart"], x_t=x, t=t ) return out def condition_score_with_grad(self, cond_fn, p_mean_var, x, t, model_kwargs=None): """ Compute what the p_mean_variance output would have been, should the model's score function be conditioned by cond_fn. See condition_mean() for details on cond_fn. Unlike condition_mean(), this instead uses the conditioning strategy from Song et al (2020). """ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) eps = eps - (1 - alpha_bar).sqrt() * cond_fn( x, t, p_mean_var, **model_kwargs ) out = p_mean_var.copy() out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) out["mean"], _, _ = self.q_posterior_mean_variance( x_start=out["pred_xstart"], x_t=x, t=t ) return out def p_sample( self, model, x, t, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, const_noise=False, ): """ Sample x_{t-1} from the model at the given timestep. :param model: the model to sample from. :param x: the current tensor at x_{t-1}. :param t: the value of t, starting at 0 for the first diffusion step. :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. :param denoised_fn: if not None, a function which applies to the x_start prediction before it is used to sample. :param cond_fn: if not None, this is a gradient function that acts similarly to the model. :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :return: a dict containing the following keys: - 'sample': a random sample from the model. - 'pred_xstart': a prediction of x_0. 
""" out = self.p_mean_variance( model, x, t, # starting clip_denoised=clip_denoised, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) # bsz x ws x nnj x nnb x 3 # noise = th.randn_like(x['rhand_joints']) # print('const_noise', const_noise) if const_noise: noise = noise[[0]].repeat(x['rhand_joints'].shape[0], 1, 1, 1, 1) nonzero_mask = ( (t != 0).float().view(-1, *([1] * (len(x['rhand_joints'].shape) - 1))) ) # no noise when t == 0 if cond_fn is not None: out["mean"] = self.condition_mean( cond_fn, out, x['rhand_joints'], t, model_kwargs=model_kwargs ) # print('mean', out["mean"].shape, out["mean"]) # print('log_variance', out["log_variance"].shape, out["log_variance"]) # print('nonzero_mask', nonzero_mask.shape, nonzero_mask) # sample # why the out only remember relative sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise # bsz x ws x nnj x 3 # # denormed # rhand_joints_mean = sample * self.std_jts.unsqueeze(0).to(sample.device) + self.avg_jts.unsqueeze(0).to(sample.device) # model_mena: bsz x ws x nnj x nnb x 3 # sample = rhand_joints_mean.unsqueeze(-2) - x['base_pts'].unsqueeze(1).unsqueeze(1) sample = (sample - x['per_frame_avg_joints_rel']) / x['per_frame_std_joints_rel'] return {"sample": sample, "pred_xstart": out["pred_xstart"]} def p_sample_with_grad( self, model, x, t, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, ): """ Sample x_{t-1} from the model at the given timestep. :param model: the model to sample from. :param x: the current tensor at x_{t-1}. :param t: the value of t, starting at 0 for the first diffusion step. :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. :param denoised_fn: if not None, a function which applies to the x_start prediction before it is used to sample. :param cond_fn: if not None, this is a gradient function that acts similarly to the model. :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :return: a dict containing the following keys: - 'sample': a random sample from the model. - 'pred_xstart': a prediction of x_0. """ with th.enable_grad(): x = x.detach().requires_grad_() out = self.p_mean_variance( model, x, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) noise = th.randn_like(x) nonzero_mask = ( (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) ) # no noise when t == 0 if cond_fn is not None: out["mean"] = self.condition_mean_with_grad( cond_fn, out, x, t, model_kwargs=model_kwargs ) sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise return {"sample": sample, "pred_xstart": out["pred_xstart"].detach()} def p_sample_loop( self, model, shape, noise=None, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, progress=False, skip_timesteps=0, init_image=None, randomize_class=False, cond_fn_with_grad=False, dump_steps=None, const_noise=False, st_timestep=None, ): ## """ Generate samples from the model. :param model: the model module. :param shape: the shape of the samples, (N, C, H, W). :param noise: if specified, the noise from the encoder to sample. Should be of the same shape as `shape`. :param clip_denoised: if True, clip x_start predictions to [-1, 1]. :param denoised_fn: if not None, a function which applies to the x_start prediction before it is used to sample. :param cond_fn: if not None, this is a gradient function that acts similarly to the model. 
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param device: if specified, the device to create the samples on.
            If not specified, use a model parameter's device.
        :param progress: if True, show a tqdm progress bar.
        :param const_noise: If True, will noise all samples with the same noise throughout sampling
        :return: a non-differentiable batch of samples.
        """
        final = None
        # if dump_steps is not None:
        dump = []

        for i, sample in enumerate(self.p_sample_loop_progressive(
            model,
            shape,
            noise=noise,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            cond_fn=cond_fn,
            model_kwargs=model_kwargs,
            device=device,
            progress=progress,
            skip_timesteps=skip_timesteps,
            init_image=init_image,
            randomize_class=randomize_class,
            cond_fn_with_grad=cond_fn_with_grad,
            const_noise=const_noise,  # the same noise #
            st_timestep=st_timestep,
        )):
            if dump_steps is not None and i in dump_steps:
                dump.append(deepcopy(sample))
            final = sample
        if dump_steps is not None:
            return dump
        return final

    def p_sample_loop_progressive(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        skip_timesteps=0,
        init_image=None,
        randomize_class=False,
        cond_fn_with_grad=False,
        const_noise=False,
        st_timestep=None,
    ):
        """
        Generate samples from the model and yield intermediate samples from
        each timestep of diffusion.

        Arguments are the same as p_sample_loop().
        Returns a generator over dicts, where each dict is the return value of
        p_sample().
        """
        ####### ==== conditional sampling from init_image here ==== #######
        ## === give joints shape here === ##
        ### ==== set the shape for sampling ==== ###
        ### === init_image should not be None === ###
        base_pts = init_image['base_pts']
        base_normals = init_image['base_normals']  ## base normals #
        rel_base_pts_to_rhand_joints = init_image['rel_base_pts_to_rhand_joints']
        # dist_base_pts_to_rhand_joints = init_image['dist_base_pts_to_rhand_joints']
        rhand_joints = init_image['rhand_joints']

        if 'sampled_base_pts_nearest_obj_pc' in init_image:
            ambient_init_image = {
                'sampled_base_pts_nearest_obj_pc': init_image['sampled_base_pts_nearest_obj_pc'],
                'sampled_base_pts_nearest_obj_vns': init_image['sampled_base_pts_nearest_obj_vns'],
            }

        init_image_avg_std_stats = {
            'rhand_joints': init_image['rhand_joints'],
            'per_frame_avg_joints_rel': init_image['per_frame_avg_joints_rel'],
            'per_frame_std_joints_rel': init_image['per_frame_std_joints_rel'],
            'per_frame_avg_joints_dists_rel': init_image['per_frame_avg_joints_dists_rel'],
            'per_frame_std_joints_dists_rel': init_image['per_frame_std_joints_dists_rel'],
        }

        if device is None:
            device = next(model.parameters()).device
        assert isinstance(shape, (tuple, list))
        if noise is not None:
            img = noise
        else:
            img = th.randn(*shape, device=device)
        # if skip_timesteps and init_image is None:
        #     rhand_joints = th.zeros_like(img)

        indices = list(range(self.num_timesteps - skip_timesteps))[::-1]
        if st_timestep is not None:  ## restrict sampling to the last st_timestep indices
            indices = indices[-st_timestep:]
            print(f"st_timestep: {st_timestep}, indices: {indices}")

        ''' No rhand_joints here '''
        # if rhand_joints is not None:  # largest variance for sampling?
        # [->t] add noise + [<-t] remove noise
        # my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0]
        # pert_rhand_joints = self.q_sample(rhand_joints, my_t, img)
        # rel_base_pts_to_pert_rhand_joints = pert_rhand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1)
        # # bsz x ws x nnjts x nnb #
        # dist_base_pts_to_pert_rhand_joints = torch.sum(
        #     rel_base_pts_to_pert_rhand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1
        # )
        # ''' Relative positions and distances normalization, strategy 2 '''
        # rel_base_pts_to_pert_rhand_joints = (rel_base_pts_to_pert_rhand_joints - self.avg_joints_rel.unsqueeze(-2)) / self.std_joints_rel.unsqueeze(-2)
        # dist_base_pts_to_pert_rhand_joints = (dist_base_pts_to_pert_rhand_joints - self.avg_joints_dists.unsqueeze(-1)) / self.std_joints_dists.unsqueeze(-1)
        # ''' Relative positions and distances normalization, strategy 2 '''
        ''' No rhand_joints here '''

        if self.denoising_stra == "rep":
            ''' Normalization Strategy 4 '''
            my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0]
            # Normalize the joints, perturb them at timestep my_t, then denormalize.
            normed_rhand_joints = (rhand_joints - self.avg_jts.unsqueeze(0).to(rhand_joints.device)) / self.std_jts.unsqueeze(0).to(rhand_joints.device)
            noise_rhand_joints = th.randn_like(normed_rhand_joints)
            pert_normed_rhand_joints = self.q_sample(normed_rhand_joints, my_t, noise=noise_rhand_joints)
            pert_rhand_joints = pert_normed_rhand_joints * self.std_jts.unsqueeze(0).to(rhand_joints.device) + self.avg_jts.unsqueeze(0).to(rhand_joints.device)

            # base_pts: bsz x nnb x 3; base_normals: bsz x nnb x 3 #
            # bsz x ws x nnj x nnb x 3 -> relative positions from base pts to rhand joints #
            rel_base_pts_to_pert_rhand_joints = pert_rhand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1)
            # signed distances from base pts to the perturbed rhand joints #
            dist_rel_base_pts_to_pert_rhand_joints = torch.sum(
                rel_base_pts_to_pert_rhand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1
            )

            k_f = 1.
            # l2_rel_base_pts_to_pert_rhand_joints: bsz x nf x nnj x nnb #
            l2_rel_base_pts_to_pert_rhand_joints = torch.norm(rel_base_pts_to_pert_rhand_joints, dim=-1)
            ### att_forces: proximity weights, higher for base points close to the joint ##
            att_forces = torch.exp(-k_f * l2_rel_base_pts_to_pert_rhand_joints)  # bsz x nf x nnj x nnb #
            att_forces = att_forces[:, :-1, :, :]  # drop the last frame to match the displacements below
            # bsz x (ws - 1) x nnj x 3 --> per-frame displacements #
            pert_rhand_joints_disp = pert_rhand_joints[:, 1:, :, :] - pert_rhand_joints[:, :-1, :, :]
            # signed displacement along each base normal: bsz x (ws - 1) x nnj x nnb #
            signed_dist_base_pts_to_pert_rhand_joints_along_normal = torch.sum(
                base_normals.unsqueeze(1).unsqueeze(1) * pert_rhand_joints_disp.unsqueeze(-2), dim=-1
            )
            # rel_base_pts_to_pert_rhand_joints_vt_normal: bsz x (ws - 1) x nnj x nnb x 3 -> displacement component perpendicular to the base normals #
            rel_base_pts_to_pert_rhand_joints_vt_normal = pert_rhand_joints_disp.unsqueeze(-2) - signed_dist_base_pts_to_pert_rhand_joints_along_normal.unsqueeze(-1) * base_normals.unsqueeze(1).unsqueeze(1)
            dist_base_pts_to_pert_rhand_joints_vt_normal = torch.sqrt(torch.sum(
                rel_base_pts_to_pert_rhand_joints_vt_normal ** 2, dim=-1
            ))
            k_a = 1.
            k_b = 1.
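            # The two contact "energies" computed next, written out explicitly
            # (d_t = joint displacement between frames t and t+1, n_b = the
            # base-point normal; 'baes' in e_disp_rel_to_baes_vt_normals is a
            # historical misspelling kept because it is the key used throughout):
            #   att_t(b)     = exp(-k_f * ||rel_t(b)||)                    proximity weight
            #   e_along(b,t) = k_a * att_t(b) * |d_t . n_b|                motion along the normal
            #   e_vt(b,t)    = k_b * att_t(b) * ||d_t - (d_t . n_b) n_b||  motion tangent to it
            # Both are then normalized by the per-frame statistics in init_image.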
e_disp_rel_to_base_along_normals = k_a * att_forces * torch.abs(signed_dist_base_pts_to_pert_rhand_joints_along_normal) # (ws - 1) x nnj x nnb # -> dist vt normals # e_disp_rel_to_baes_vt_normals = k_b * att_forces * dist_base_pts_to_pert_rhand_joints_vt_normal # nf x nnj x nnb ---> dist_vt_normals -> nf x nnj x nnb # # torch.sqrt() ## # dist_base_pts_to_pert_rhand_joints, rel_base_pts_to_pert_rhand_joints # # e_disp_rel_to_base_along_normals, e_disp_rel_to_baes_vt_normals # # print("per frmae avg jalong normals,", x_start['per_frame_avg_disp_along_normals'].size(), "per frame std along normals", x_start['per_frame_std_disp_along_normals'].size(), "e_disp_rel_to_base_along_normals", e_disp_rel_to_base_along_normals.size(), "e_disp_rel_to_baes_vt_normals", e_disp_rel_to_baes_vt_normals.size()) # rel to base along normals # e_disp_rel_to_base_along_normals = (e_disp_rel_to_base_along_normals - init_image['per_frame_avg_disp_along_normals']) / init_image['per_frame_std_disp_along_normals'] e_disp_rel_to_baes_vt_normals = (e_disp_rel_to_baes_vt_normals - init_image['per_frame_avg_disp_vt_normals']) / init_image['per_frame_std_disp_vt_normals'] ### # bsz x ws x nnj x nnb x 3 # # rel_base_pts_to_pert_rhand_joints = (rel_base_pts_to_pert_rhand_joints - init_image['per_frame_avg_joints_rel']) / init_image['per_frame_std_joints_rel'] # dist_rel_base_pts_to_pert_rhand_joints: bsz x ws x nnj x nnb # --> rel and dists ## dist_rel_base_pts_to_pert_rhand_joints = (dist_rel_base_pts_to_pert_rhand_joints - init_image['per_frame_avg_joints_dists_rel'] ) / init_image['per_frame_std_joints_dists_rel'] ## dist pts to pert joints ## dist_base_pts_to_pert_rhand_joints = dist_rel_base_pts_to_pert_rhand_joints rel_base_pts_to_pert_rhand_joints = rel_base_pts_to_pert_rhand_joints # rel_noise = th.randn_like(rel_base_pts_to_rhand_joints) # # bsz x ws x nnjts x 3 # ## sample perturbed joints ## # rel_base_pts_to_pert_rhand_joints = self.q_sample(rel_base_pts_to_rhand_joints, my_t, noise=rel_noise) # dist_noise = th.randn_like(dist_base_pts_to_rhand_joints) # dist_base_pts_to_pert_rhand_joints = self.q_sample(dist_base_pts_to_rhand_joints, my_t, noise=dist_noise) ''' Normalization Strategy 4 ''' elif self.denoising_stra == "motion_to_rep": my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0] joints_noise = th.randn_like(rhand_joints) pert_rhand_joints = self.q_sample(rhand_joints, my_t, noise=joints_noise) pert_rhand_joints_denorm = pert_rhand_joints * self.std_jts.unsqueeze(0).to(rhand_joints.device) + self.avg_jts.unsqueeze(0).to(rhand_joints.device) # jbs zx nf x nnj x nnb x 3 rel_base_pts_to_pert_rhand_joints = pert_rhand_joints_denorm.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1) dist_base_pts_to_pert_rhand_joints = torch.sum( rel_base_pts_to_pert_rhand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1 ) else: raise ValueError(f"Unrecognized denoising stra: {self.denoising_stra}") input_data = { 'base_pts': base_pts, 'base_normals': base_normals, 'rel_base_pts_to_rhand_joints': rel_base_pts_to_pert_rhand_joints, 'dist_base_pts_to_rhand_joints': dist_base_pts_to_pert_rhand_joints, } if 'sampled_base_pts_nearest_obj_pc' in init_image: input_data.update(ambient_init_image) input_data.update( { 'e_disp_rel_to_base_along_normals': e_disp_rel_to_base_along_normals, 'e_disp_rel_to_baes_vt_normals': e_disp_rel_to_baes_vt_normals, } ) # input input_data.update(init_image_avg_std_stats) model_kwargs = { k: val for k, val in init_image.items() if k not in input_data } if progress: # Lazy import 
so that we don't depend on tqdm. from tqdm.auto import tqdm indices = tqdm(indices) for i in indices: t = th.tensor([i] * shape[0], device=device) if randomize_class and 'y' in model_kwargs: model_kwargs['y'] = th.randint(low=0, high=model.num_classes, size=model_kwargs['y'].shape, # size of y device=model_kwargs['y'].device) # device of y with th.no_grad(): # inter_optim # # p_sample_with_grad sample_fn = self.p_sample_with_grad if cond_fn_with_grad else self.p_sample # out = sample_fn( model, input_data, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, const_noise=const_noise, ) # yield out img = out["sample"] # bsz x ws x nnj x nnb x 3 # ''' Relative positions and distances normalization, strategy 2 ''' # img = img * self.std_joints_rel.unsqueeze(-2) + self.avg_joints_rel.unsqueeze(-2) ''' Relative positions and distances normalization, strategy 2 ''' ''' Relative positions and distances normalization, strategy 4 ''' # img = img * (self.maxx_rel - self.minn_rel).unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0).to(img.device) ''' Relative positions and distances normalization, strategy 4 ''' if self.denoising_stra == "rep": # ''' Relative positions and distances normalization, strategy 3 ''' per_frame_avg_joints_rel = init_image['per_frame_avg_joints_rel'] per_frame_std_joints_rel = init_image['per_frame_std_joints_rel'] # std joints rel # # per_frame_avg_joints_dists_rel = init_image['per_frame_avg_joints_dists_rel'] # per_frame_std_joints_dists_rel = init_image['per_frame_std_joints_dists_rel'] img = img * per_frame_std_joints_rel + per_frame_avg_joints_rel ''' Relative positions and distances normalization, strategy 3 ''' ''' sampled base pts based joints ''' ## img + base_pts.unsqueeze(1).unsqueeze(1) ## # --> decrease the related potential when gathering information for sampled_rhand_joints --> # #### === todo -> from sampled rel to rhandjoints == sampled_base_pts_based_joints = img + base_pts.unsqueeze(1).unsqueeze(1) # bsz x ws x nnj x nnb x 3 sampled_rhand_joints = sampled_base_pts_based_joints.mean(dim=-2) ## a simple averaging strategy # sampled_rhand_joints = sampled_base_pts_based_joints[..., 0, :] ## a simple averaging strategy if self.inter_optim: # sampled_rhand_joints = model_util.optimize_sampled_hand_joints(sampled_rhand_joints, img, None, base_pts, base_normals) obj_verts, obj_normals, obj_faces = init_image["obj_verts"], init_image["obj_normals"], init_image["obj_faces"] # sampled_rhand_joints = model_util.optimize_sampled_hand_joints(sampled_rhand_joints, img, None, base_pts, base_normals) # optimize_sampled_hand_joints_wobj(sampled_joints, rel_base_pts_to_joints, dists_base_pts_to_joints, base_pts, base_normals, obj_verts, obj_normals, obj_faces) ### === sampled rhand joints === ### # sampled_rhand_joints = model_util.optimize_sampled_hand_joints_wobj(sampled_rhand_joints, img, None, base_pts, base_normals, obj_verts, obj_normals, obj_faces) # optimize_sampled_hand_joints_wobj_v2 sampled_rhand_joints = model_util.optimize_sampled_hand_joints_wobj_v2(sampled_rhand_joints, img, None, base_pts, base_normals, obj_verts, obj_normals, obj_faces) rel_base_pts_to_pert_rhand_joints = sampled_rhand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1) # bsz x ws x nnjts x nnb dist_base_pts_to_pert_rhand_joints = torch.sum( rel_base_pts_to_pert_rhand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1 ) k_f = 1. 
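            # As during initialization above, re-derive the proximity-weighted
            # displacement energies from the current (optionally optimization-
            # refined) sampled joints, so the next denoising step is conditioned
            # on statistics consistent with its own sample; the e_disp formulas
            # are spelled out where they are first computed.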
# l2_rel_base_pts_to_pert_rhand_joints: bsz x nf x nnj x nnb # l2_rel_base_pts_to_pert_rhand_joints = torch.norm(rel_base_pts_to_pert_rhand_joints, dim=-1) ### att_forces ## att_forces = torch.exp(-k_f * l2_rel_base_pts_to_pert_rhand_joints) # bsz x nf x nnj x nnb # att_forces = att_forces[:, :-1, :, :] # rhand_joints: ws x nnj x 3 # -> (ws - 1) x nnj x 3 ## rhand_joints ## # bsz x (ws - 1) x nnj x 3 --> displacements s# sampled_rhand_joints_disp = sampled_rhand_joints[:, 1:, :, :] - sampled_rhand_joints[:, :-1, :, :] # distance -- base_normalss,; (ws - 1) x nnj x nnb x 3 --> bsz x (ws - 1) x nnj x nnb # # signed_dist_base_pts_to_pert_rhand_joints_along_normal # bsz x (ws - 1) x nnj x nnb # signed_dist_base_pts_to_sampled_rhand_joints_along_normal = torch.sum( base_normals.unsqueeze(1).unsqueeze(1) * sampled_rhand_joints_disp.unsqueeze(-2), dim=-1 ) # rel_base_pts_to_pert_rhand_joints_vt_normal: bsz x (ws -1) x nnj x nnb x 3 -> the relative positions vertical to base normals # rel_base_pts_to_sampled_rhand_joints_vt_normal = sampled_rhand_joints_disp.unsqueeze(-2) - signed_dist_base_pts_to_sampled_rhand_joints_along_normal.unsqueeze(-1) * base_normals.unsqueeze(1).unsqueeze(1) dist_base_pts_to_sampled_rhand_joints_vt_normal = torch.sqrt(torch.sum( rel_base_pts_to_sampled_rhand_joints_vt_normal ** 2, dim=-1 )) k_a = 1. k_b = 1. e_disp_rel_to_base_along_normals = k_a * att_forces * torch.abs(signed_dist_base_pts_to_sampled_rhand_joints_along_normal) # (ws - 1) x nnj x nnb # -> dist vt normals # e_disp_rel_to_baes_vt_normals = k_b * att_forces * dist_base_pts_to_sampled_rhand_joints_vt_normal # nf x nnj x nnb ---> dist_vt_normals -> nf x nnj x nnb # # torch.sqrt() ## e_disp_rel_to_base_along_normals = (e_disp_rel_to_base_along_normals - init_image['per_frame_avg_disp_along_normals']) / init_image['per_frame_std_disp_along_normals'] e_disp_rel_to_baes_vt_normals = (e_disp_rel_to_baes_vt_normals - init_image['per_frame_avg_disp_vt_normals']) / init_image['per_frame_std_disp_vt_normals'] ## denoi; if self.denoising_stra == "rep": ''' Relative positions and distances normalization, strategy 3 ''' per_frame_avg_joints_dists_rel = init_image['per_frame_avg_joints_dists_rel'] per_frame_std_joints_dists_rel = init_image['per_frame_std_joints_dists_rel'] # rel base_pts to joints # rel_base_pts_to_pert_rhand_joints = (rel_base_pts_to_pert_rhand_joints - per_frame_avg_joints_rel ) / per_frame_std_joints_rel dist_base_pts_to_pert_rhand_joints = (dist_base_pts_to_pert_rhand_joints - per_frame_avg_joints_dists_rel) / per_frame_std_joints_dists_rel ''' Relative positions and distances normalization, strategy 3 ''' ## denoise in a regular representations space ## # if self.denoising_stra == "motion_to_rep": # sampled_rhand_joints_normed = (sampled_rhand_joints - self.avg_jts.unsqueeze(0).to(rhand_joints.device)) / self.std_jts.unsqueeze(0).to(rhand_joints.device) ''' Relative positions and distances normalization, strategy 4 ''' # rel_base_pts_to_pert_rhand_joints = rel_base_pts_to_pert_rhand_joints / (self.maxx_rel - self.minn_rel).unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0).to(img.device) # dist_base_pts_to_pert_rhand_joints = dist_base_pts_to_pert_rhand_joints / (self.maxx_dists - self.minn_dists).unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0).squeeze(-1).to(img.device) ''' Relative positions and distances normalization, strategy 4 ''' input_data = { 'sampled_rhand_joints': sampled_rhand_joints, # 'sampled_rhand_joints': pert_rhand_joints, # 'sampled_rhand_joints': 
base_pts.unsqueeze(1).repeat(1, rhand_joints.size(1), 1, 1), 'base_pts': base_pts, 'base_normals': base_normals, 'rel_base_pts_to_rhand_joints': rel_base_pts_to_pert_rhand_joints, # 'rel_base_pts_to_rhand_joints': img, 'dist_base_pts_to_rhand_joints': dist_base_pts_to_pert_rhand_joints, 'e_disp_rel_to_base_along_normals': e_disp_rel_to_base_along_normals, 'e_disp_rel_to_baes_vt_normals': e_disp_rel_to_baes_vt_normals, } if 'sampled_base_pts_nearest_obj_pc' in init_image: input_data.update(ambient_init_image) input_data.update(init_image_avg_std_stats) yield input_data def ddim_sample( self, model, x, t, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, eta=0.0, ): """ Sample x_{t-1} from the model using DDIM. Same usage as p_sample(). """ out_orig = self.p_mean_variance( model, x, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) if cond_fn is not None: out = self.condition_score(cond_fn, out_orig, x, t, model_kwargs=model_kwargs) else: out = out_orig # Usually our model outputs epsilon, but we re-derive it # in case we used x_start or x_prev prediction. eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) sigma = ( eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * th.sqrt(1 - alpha_bar / alpha_bar_prev) ) # Equation 12. ##; hand position and relative noise = th.randn_like(x) mean_pred = ( out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps ) nonzero_mask = ( (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) ) # no noise when t == 0 sample = mean_pred + nonzero_mask * sigma * noise return {"sample": sample, "pred_xstart": out_orig["pred_xstart"]} def ddim_sample_with_grad( self, model, x, t, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, eta=0.0, ): """ Sample x_{t-1} from the model using DDIM. Same usage as p_sample(). """ with th.enable_grad(): x = x.detach().requires_grad_() out_orig = self.p_mean_variance( model, x, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) if cond_fn is not None: out = self.condition_score_with_grad(cond_fn, out_orig, x, t, model_kwargs=model_kwargs) else: out = out_orig out["pred_xstart"] = out["pred_xstart"].detach() # Usually our model outputs epsilon, but we re-derive it # in case we used x_start or x_prev prediction. eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) sigma = ( eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * th.sqrt(1 - alpha_bar / alpha_bar_prev) ) # Equation 12. noise = th.randn_like(x) mean_pred = ( out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps ) nonzero_mask = ( (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) ) # no noise when t == 0 sample = mean_pred + nonzero_mask * sigma * noise return {"sample": sample, "pred_xstart": out_orig["pred_xstart"].detach()} ## def ddim_reverse_sample( self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None, eta=0.0, ): """ Sample x_{t+1} from the model using DDIM reverse ODE. 
""" assert eta == 0.0, "Reverse ODE only for deterministic path" out = self.p_mean_variance( model, x, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) # Usually our model outputs epsilon, but we re-derive it # in case we used x_start or x_prev prediction. eps = ( _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - out["pred_xstart"] ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) # Equation 12. reversed mean_pred = ( out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps ) return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} def ddim_sample_loop( self, model, shape, noise=None, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, progress=False, eta=0.0, skip_timesteps=0, init_image=None, randomize_class=False, cond_fn_with_grad=False, dump_steps=None, const_noise=False, ): """ Generate samples from the model using DDIM. Same usage as p_sample_loop(). """ if dump_steps is not None: raise NotImplementedError() if const_noise == True: raise NotImplementedError() final = None for sample in self.ddim_sample_loop_progressive( model, shape, noise=noise, clip_denoised=clip_denoised, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, device=device, progress=progress, eta=eta, skip_timesteps=skip_timesteps, init_image=init_image, randomize_class=randomize_class, cond_fn_with_grad=cond_fn_with_grad, ): final = sample return final["sample"] def ddim_sample_loop_progressive( self, model, shape, noise=None, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, progress=False, eta=0.0, skip_timesteps=0, init_image=None, randomize_class=False, cond_fn_with_grad=False, ): """ Use DDIM to sample from the model and yield intermediate samples from each timestep of DDIM. Same usage as p_sample_loop_progressive(). """ if device is None: device = next(model.parameters()).device assert isinstance(shape, (tuple, list)) if noise is not None: img = noise else: img = th.randn(*shape, device=device) if skip_timesteps and init_image is None: init_image = th.zeros_like(img) indices = list(range(self.num_timesteps - skip_timesteps))[::-1] if init_image is not None: my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0] img = self.q_sample(init_image, my_t, img) if progress: # Lazy import so that we don't depend on tqdm. from tqdm.auto import tqdm indices = tqdm(indices) for i in indices: t = th.tensor([i] * shape[0], device=device) if randomize_class and 'y' in model_kwargs: model_kwargs['y'] = th.randint(low=0, high=model.num_classes, size=model_kwargs['y'].shape, device=model_kwargs['y'].device) with th.no_grad(): sample_fn = self.ddim_sample_with_grad if cond_fn_with_grad else self.ddim_sample out = sample_fn( model, img, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, eta=eta, ) yield out img = out["sample"] def plms_sample( self, model, x, t, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, cond_fn_with_grad=False, order=2, old_out=None, ): """ Sample x_{t-1} from the model using Pseudo Linear Multistep. Same usage as p_sample(). 
""" if not int(order) or not 1 <= order <= 4: raise ValueError('order is invalid (should be int from 1-4).') def get_model_output(x, t): with th.set_grad_enabled(cond_fn_with_grad and cond_fn is not None): x = x.detach().requires_grad_() if cond_fn_with_grad else x out_orig = self.p_mean_variance( model, x, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) if cond_fn is not None: if cond_fn_with_grad: out = self.condition_score_with_grad(cond_fn, out_orig, x, t, model_kwargs=model_kwargs) x = x.detach() else: out = self.condition_score(cond_fn, out_orig, x, t, model_kwargs=model_kwargs) else: out = out_orig # Usually our model outputs epsilon, but we re-derive it # in case we used x_start or x_prev prediction. eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) return eps, out, out_orig alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) eps, out, out_orig = get_model_output(x, t) if order > 1 and old_out is None: # Pseudo Improved Euler old_eps = [eps] mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev) * eps eps_2, _, _ = get_model_output(mean_pred, t - 1) eps_prime = (eps + eps_2) / 2 pred_prime = self._predict_xstart_from_eps(x, t, eps_prime) mean_pred = pred_prime * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev) * eps_prime else: # Pseudo Linear Multistep (Adams-Bashforth) old_eps = old_out["old_eps"] old_eps.append(eps) cur_order = min(order, len(old_eps)) if cur_order == 1: eps_prime = old_eps[-1] elif cur_order == 2: eps_prime = (3 * old_eps[-1] - old_eps[-2]) / 2 elif cur_order == 3: eps_prime = (23 * old_eps[-1] - 16 * old_eps[-2] + 5 * old_eps[-3]) / 12 elif cur_order == 4: eps_prime = (55 * old_eps[-1] - 59 * old_eps[-2] + 37 * old_eps[-3] - 9 * old_eps[-4]) / 24 else: raise RuntimeError('cur_order is invalid.') pred_prime = self._predict_xstart_from_eps(x, t, eps_prime) mean_pred = pred_prime * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev) * eps_prime if len(old_eps) >= order: old_eps.pop(0) nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) sample = mean_pred * nonzero_mask + out["pred_xstart"] * (1 - nonzero_mask) return {"sample": sample, "pred_xstart": out_orig["pred_xstart"], "old_eps": old_eps} def plms_sample_loop( self, model, shape, noise=None, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, progress=False, skip_timesteps=0, init_image=None, randomize_class=False, cond_fn_with_grad=False, order=2, ): """ Generate samples from the model using Pseudo Linear Multistep. Same usage as p_sample_loop(). """ final = None for sample in self.plms_sample_loop_progressive( model, shape, noise=noise, clip_denoised=clip_denoised, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, device=device, progress=progress, skip_timesteps=skip_timesteps, init_image=init_image, randomize_class=randomize_class, cond_fn_with_grad=cond_fn_with_grad, order=order, ): final = sample return final["sample"] def plms_sample_loop_progressive( self, model, shape, noise=None, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, progress=False, skip_timesteps=0, init_image=None, randomize_class=False, cond_fn_with_grad=False, order=2, ): """ Use PLMS to sample from the model and yield intermediate samples from each timestep of PLMS. Same usage as p_sample_loop_progressive(). 
""" if device is None: device = next(model.parameters()).device assert isinstance(shape, (tuple, list)) if noise is not None: img = noise else: img = th.randn(*shape, device=device) if skip_timesteps and init_image is None: init_image = th.zeros_like(img) indices = list(range(self.num_timesteps - skip_timesteps))[::-1] if init_image is not None: my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0] img = self.q_sample(init_image, my_t, img) if progress: # Lazy import so that we don't depend on tqdm. from tqdm.auto import tqdm indices = tqdm(indices) old_out = None for i in indices: t = th.tensor([i] * shape[0], device=device) if randomize_class and 'y' in model_kwargs: model_kwargs['y'] = th.randint(low=0, high=model.num_classes, size=model_kwargs['y'].shape, device=model_kwargs['y'].device) with th.no_grad(): out = self.plms_sample( model, img, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, cond_fn_with_grad=cond_fn_with_grad, order=order, old_out=old_out, ) yield out old_out = out img = out["sample"] def _vb_terms_bpd( self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None ): """ Get a term for the variational lower-bound. The resulting units are bits (rather than nats, as one might expect). This allows for comparison to other papers. :return: a dict with the following keys: - 'output': a shape [N] tensor of NLLs or KLs. - 'pred_xstart': the x_0 predictions. """ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( x_start=x_start, x_t=x_t, t=t ) out = self.p_mean_variance( model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs ) kl = normal_kl( true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] ) kl = mean_flat(kl) / np.log(2.0) decoder_nll = -discretized_gaussian_log_likelihood( x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] ) assert decoder_nll.shape == x_start.shape decoder_nll = mean_flat(decoder_nll) / np.log(2.0) # At the first timestep return the decoder NLL, # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) output = th.where((t == 0), decoder_nll, kl) return {"output": output, "pred_xstart": out["pred_xstart"]} ## training losses def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, dataset=None): """ # training losses # training losses for rel/dist representations ## Compute training losses for a single timestep. :param model: the model to evaluate loss on. :param x_start: the [N x C x ...] tensor of inputs. :param t: a batch of timestep indices. :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :param noise: if specified, the specific Gaussian noise to try to remove. :return: a dict with the key "loss" containing a tensor of shape [N]. Some mean or variance settings may also have other keys. 
""" # enc = model.model._modules['module'] ### enc model.model enc = model.model mask = model_kwargs['y']['mask'] ## rot2xyz; # get_xyz = lambda sample: enc.rot2xyz(sample, mask=None, pose_rep=enc.pose_rep, translation=enc.translation, glob=enc.glob, ## rot2xyz; ## # jointstype='vertices', # 3.4 iter/sec # USED ALSO IN MotionCLIP jointstype='smpl', # 3.4 iter/sec vertstrans=False) # bsz x ws x nnj x 3 # # base_pts: bsz x nnb x 3 # # base_normals: bsz x nnb x 3 # # bsz x ws x nnjts x 3 # rhand_joints = x_start['rhand_joints'] # bsz x nnbase x 3 # base_pts = x_start['base_pts'] # bsz x ws x nnbase x 3 # base_normals = x_start['base_normals'] # # bsz x ws x nnjts x nnbase x 3 # # rel_base_pts_to_rhand_joints = x_start['rel_base_pts_to_rhand_joints'] # # bsz x ws x nnjts x nnbase # # dist_base_pts_to_rhand_joints = x_start['dist_base_pts_to_rhand_joints'] # normalization strategy for joints and that for the representation values # if 'sampled_base_pts_nearest_obj_pc' in x_start: ambient_xstart_dict = { 'sampled_base_pts_nearest_obj_pc': x_start['sampled_base_pts_nearest_obj_pc'], 'sampled_base_pts_nearest_obj_vns': x_start['sampled_base_pts_nearest_obj_vns'], } if 'e_disp_rel_to_base_along_normals' in x_start: e_disp_xstart_dict = { # x_start['e_disp_rel_to_base_along_normals'], x_start['e_disp_rel_to_baes_vt_normals'], 'per_frame_avg_disp_along_normals': x_start['per_frame_avg_disp_along_normals'], 'per_frame_std_disp_along_normals': x_start['per_frame_std_disp_along_normals'], 'per_frame_avg_disp_vt_normals': x_start['per_frame_avg_disp_vt_normals'], 'per_frame_std_disp_vt_normals': x_start['per_frame_std_disp_vt_normals'], 'e_disp_rel_to_base_along_normals': x_start['e_disp_rel_to_base_along_normals'], 'e_disp_rel_to_baes_vt_normals': x_start['e_disp_rel_to_baes_vt_normals'], } ''' Permute rhand joints for permuting relative position values and distance values ''' # if noise is None: # rhand_joints_noise = th.randn_like(rhand_joints) # # bsz x ws x nnjts x 3 # ## sample perturbed joints ## # pert_rhand_joints = self.q_sample(rhand_joints, t, noise=rhand_joints_noise) # # bsz x ws x nnjts x nnb x 3 # # # print(f"pert_rhand_joints: {pert_rhand_joints.size()}, base_pts: {base_pts.size()}") # rel_base_pts_to_pert_rhand_joints = pert_rhand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1) # # bsz x ws x nnjts x nnb # dist_base_pts_to_pert_rhand_joints = torch.sum( # rel_base_pts_to_pert_rhand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1 # ) # ''' Relative positions and distances normalization, strategy 2 ''' # # bsz x ws x nf x nb x 3 # # 1 x nf x 3 # rel_base_pts_to_pert_rhand_joints = (rel_base_pts_to_pert_rhand_joints - self.avg_joints_rel.unsqueeze(-2).unsqueeze(1).to(pert_rhand_joints.device)) / self.std_joints_rel.unsqueeze(-2).unsqueeze(1).to(pert_rhand_joints.device) # dist_base_pts_to_pert_rhand_joints = (dist_base_pts_to_pert_rhand_joints - self.avg_joints_dists.unsqueeze(1).unsqueeze(-1).to(pert_rhand_joints.device)) / self.std_joints_dists.unsqueeze(1).unsqueeze(-1).to(pert_rhand_joints.device) # ''' Relative positions and distances normalization, strategy 2 ''' ''' Permute rhand joints for permuting relative position values and distance values ''' ''' GET rel and dists ''' # denoising if self.denoising_stra == "rep": # bsz x ws x nnj x nnb x 3 # # avg_jts: 1 x nnj x 3 # std_jts: 1 x nnj x 3 # rhand_joints: bsz x ws x nnj x 3; normalize rhand joitns # normed_rhand_joints = (rhand_joints - self.avg_jts.unsqueeze(0).to(rhand_joints.device)) / 
self.std_jts.unsqueeze(0).to(rhand_joints.device) noise_rhand_joints = th.randn_like(normed_rhand_joints) pert_normed_rhand_joints = self.q_sample(normed_rhand_joints, t, noise=noise_rhand_joints) # pert_rhand_joints: bsz x nnj x 3 ## -> pert_rhand_joints = pert_normed_rhand_joints * self.std_jts.unsqueeze(0).to(rhand_joints.device) + self.avg_jts.unsqueeze(0).to(rhand_joints.device) # avg_pert_normed_rhand_joints = pert_normed_rhand_joints[0].mean(dim=0).mean(dim=0) # avg_pert_rhand_joints = pert_rhand_joints[0].mean(dim=0).mean(dim=0) # print(f"avg_normed_jts: {avg_pert_normed_rhand_joints}, avg_pert_rhand_joints: {avg_pert_rhand_joints}") # base_pts: bsz x nnb x 3 # # base_normals: bsz x nnb x 3 ## rel_base_pts_to_pert_rhand_joints = pert_rhand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1) # bsz x nnj x nnb x 3 -> rel_bae_pts to rhand joints # # bsz x ws x nnj x nnb x 3 ## # avg_rel_base_pts_to_pert_rhand_joints = rel_base_pts_to_pert_rhand_joints[0].mean(dim=0).mean(dim=0)[0] # print(f"Before normalization: avg_rel {avg_rel_base_pts_to_pert_rhand_joints}") dist_rel_base_pts_to_pert_rhand_joints = torch.sum( rel_base_pts_to_pert_rhand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1 ) k_f = 1. # l2_rel_base_pts_to_pert_rhand_joints: bsz x nf x nnj x nnb # l2_rel_base_pts_to_pert_rhand_joints = torch.norm(rel_base_pts_to_pert_rhand_joints, dim=-1) ### att_forces ## att_forces = torch.exp(-k_f * l2_rel_base_pts_to_pert_rhand_joints) # bsz x nf x nnj x nnb # att_forces = att_forces[:, :-1, :, :] # attraction forces -1 # # rhand_joints: ws x nnj x 3 # -> (ws - 1) x nnj x 3 ## rhand_joints ## # bsz x (ws - 1) x nnj x 3 --> displacements s# pert_rhand_joints_disp = pert_rhand_joints[:, 1:, :, :] - pert_rhand_joints[:, :-1, :, :] # distance -- base_normalss,; (ws - 1) x nnj x nnb x 3 --> bsz x (ws - 1) x nnj x nnb # # signed_dist_base_pts_to_pert_rhand_joints_along_normal # bsz x (ws - 1) x nnj x nnb # signed_dist_base_pts_to_pert_rhand_joints_along_normal = torch.sum( base_normals.unsqueeze(1).unsqueeze(1) * pert_rhand_joints_disp.unsqueeze(-2), dim=-1 ) # rel_base_pts_to_pert_rhand_joints_vt_normal: bsz x (ws -1) x nnj x nnb x 3 -> the relative positions vertical to base normals # rel_base_pts_to_pert_rhand_joints_vt_normal = pert_rhand_joints_disp.unsqueeze(-2) - signed_dist_base_pts_to_pert_rhand_joints_along_normal.unsqueeze(-1) * base_normals.unsqueeze(1).unsqueeze(1) dist_base_pts_to_pert_rhand_joints_vt_normal = torch.sqrt(torch.sum( rel_base_pts_to_pert_rhand_joints_vt_normal ** 2, dim=-1 )) k_a = 1. k_b = 1. 
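            # The two contact "energy" terms computed below follow, per frame:
            #   e_along = k_a * exp(-k_f * ||rel||) * |disp . n|              (along the base normal)
            #   e_vt    = k_b * exp(-k_f * ||rel||) * ||disp - (disp . n) n||  (vertical to it)
            # where disp is the frame-to-frame joint displacement and n the base-point normal.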
e_disp_rel_to_base_along_normals = k_a * att_forces * torch.abs(signed_dist_base_pts_to_pert_rhand_joints_along_normal) # (ws - 1) x nnj x nnb # -> dist vt normals # e_disp_rel_to_baes_vt_normals = k_b * att_forces * dist_base_pts_to_pert_rhand_joints_vt_normal # nf x nnj x nnb ---> dist_vt_normals -> nf x nnj x nnb # # torch.sqrt() ## # dist_base_pts_to_pert_rhand_joints, rel_base_pts_to_pert_rhand_joints # # e_disp_rel_to_base_along_normals, e_disp_rel_to_baes_vt_normals # # print("per frmae avg jalong normals,", x_start['per_frame_avg_disp_along_normals'].size(), "per frame std along normals", x_start['per_frame_std_disp_along_normals'].size(), "e_disp_rel_to_base_along_normals", e_disp_rel_to_base_along_normals.size(), "e_disp_rel_to_baes_vt_normals", e_disp_rel_to_baes_vt_normals.size()) e_disp_rel_to_base_along_normals = (e_disp_rel_to_base_along_normals - x_start['per_frame_avg_disp_along_normals']) / x_start['per_frame_std_disp_along_normals'] e_disp_rel_to_baes_vt_normals = (e_disp_rel_to_baes_vt_normals - x_start['per_frame_avg_disp_vt_normals']) / x_start['per_frame_std_disp_vt_normals'] ### # bsz x ws x nnj x nnb x 3 # # rel_base_pts_to_pert_rhand_joints = (rel_base_pts_to_pert_rhand_joints - x_start['per_frame_avg_joints_rel']) / x_start['per_frame_std_joints_rel'] # dist_rel_base_pts_to_pert_rhand_joints: bsz x ws x nnj x nnb # --> rel and dists ## dist_rel_base_pts_to_pert_rhand_joints = (dist_rel_base_pts_to_pert_rhand_joints - x_start['per_frame_avg_joints_dists_rel'] ) / x_start['per_frame_std_joints_dists_rel'] dist_base_pts_to_pert_rhand_joints = dist_rel_base_pts_to_pert_rhand_joints rel_base_pts_to_pert_rhand_joints = rel_base_pts_to_pert_rhand_joints # # bsz x ws x nnj x nnb x 3 ## # avg_rel_base_pts_to_pert_rhand_joints = rel_base_pts_to_pert_rhand_joints[0].mean(dim=0).mean(dim=0)[0] # print(f"After normalization: avg_rel {avg_rel_base_pts_to_pert_rhand_joints}") # if noise is None: # # rel_noise = th.randn_like(rel_base_pts_to_rhand_joints) # # if # ### # # per_frame_avg_joints_rel = x_start['per_frame_avg_joints_rel'] # # per_frame_std_joints_rel = x_start['per_frame_std_joints_rel'] # std joints rel # # # denorm_rel_base_pts_to_rhand_joints = (rel_base_pts_to_rhand_joints * per_frame_std_joints_rel) + per_frame_avg_joints_rel # # denorm_pert_rel_base_pts_to_rhand_joints = # # norm_rel_noise = (rel_noise - ) # ''' stra 1 -> independent noise ''' # # bsz x ws x nnjts x 3 # ## sample perturbed joints ## # # rel_base_pts_to_pert_rhand_joints = self.q_sample(rel_base_pts_to_rhand_joints, t, noise=rel_noise) # ''' stra 2 -> same noise ''' # rel_noise = rel_noise[:, :, 0, :, :].unsqueeze(2).repeat(1, 1, rel_base_pts_to_rhand_joints.size(2), 1, 1).contiguous() # rel_base_pts_to_pert_rhand_joints = self.q_sample(rel_base_pts_to_rhand_joints, t, noise=rel_noise) # ''' stra 2 -> same noise ''' # # normalization for each framej -> the relative positions and signed distances # # dist_noise = th.randn_like(dist_base_pts_to_rhand_joints) # dist_base_pts_to_pert_rhand_joints = self.q_sample(dist_base_pts_to_rhand_joints, t, noise=dist_noise) elif self.denoising_stra == "motion_to_rep": # print(f"Using denoising stra: {self.denoising_stra}") joints_noise = torch.randn_like(rhand_joints) pert_rhand_joints = self.q_sample(rhand_joints, t, noise=joints_noise) # q_sample for the noisy joints # pert_rhand_joints: bsz x nf x nnj x 3 ## --> pert joints # base_pts: bsz x nnb x 3 # avg jts and std jts ## pert_rhand_joints_denorm = pert_rhand_joints * 
self.std_jts.unsqueeze(0).to(rhand_joints.device) + self.avg_jts.unsqueeze(0).to(rhand_joints.device)
            # bsz x nf x nnj x nnb x 3 #
            rel_base_pts_to_pert_rhand_joints = pert_rhand_joints_denorm.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1)
            dist_base_pts_to_pert_rhand_joints = torch.sum(
                rel_base_pts_to_pert_rhand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1
            )
        else:
            raise ValueError(f"Unrecognized denoising_stra: {self.denoising_stra}")
        ''' GET rel and dists '''

        input_data = {
            'base_pts': base_pts.clone(),
            'base_normals': base_normals.clone(),
            'rel_base_pts_to_rhand_joints': rel_base_pts_to_pert_rhand_joints.clone(),
            'dist_base_pts_to_rhand_joints': dist_base_pts_to_pert_rhand_joints.clone(),
        }
        if 'sampled_base_pts_nearest_obj_pc' in x_start:
            input_data.update(ambient_xstart_dict)
        # dist_base_pts_to_pert_rhand_joints, rel_base_pts_to_pert_rhand_joints #
        # e_disp_rel_to_base_along_normals, e_disp_rel_to_baes_vt_normals #
        # bsz x (ws - 1) x nnj x nnb #
        input_data.update(
            {
                'e_disp_rel_to_base_along_normals': e_disp_rel_to_base_along_normals,
                'e_disp_rel_to_baes_vt_normals': e_disp_rel_to_baes_vt_normals,
            }
        )
        # input_data.update(
        #     {k: x_start[k].clone() for k in x_start if k not in input_data}
        # )

        if model_kwargs is None:
            model_kwargs = {}
        # if noise is None:
        #     noise = th.randn_like(x_start)
        # x_t = self.q_sample(x_start, t, noise=noise)

        terms = {}
        model_output, e_model_output = model(input_data, self._scale_timesteps(t).clone())  # model_output ---> model x_t #
        # if self.model_var_type in [
        #     ModelVarType.LEARNED,
        #     ModelVarType.LEARNED_RANGE,
        # ]:
        #     B, C = x_t.shape[:2]
        #     assert model_output.shape == (B, C * 2, *x_t.shape[2:])
        #     model_output, model_var_values = th.split(model_output, C, dim=1)
        #     # Learn the variance using the variational bound, but don't let
        #     # it affect our mean prediction.
        #     frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
        #     terms["vb"] = self._vb_terms_bpd(
        #         model=lambda *args, r=frozen_out: r,
        #         x_start=x_start,
        #         x_t=x_t,
        #         t=t,
        #         clip_denoised=False,
        #     )["output"]
        #     if self.loss_type == LossType.RESCALED_MSE:
        #         # Divide by 1000 for equivalence with initial implementation.
        #         # Without a factor of 1/1000, the VB term hurts the MSE term.
# terms["vb"] *= self.num_timesteps / 1000.0 # target = { # # q posterior mean variance # # # ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( # # x_start=x_start, x_t=x_t, t=t # # )[0], # ModelMeanType.START_X: x_start, # # ModelMeanType.EPSILON: noise, # }[self.model_mean_type] # model mean type --> mean type # # assert model_output.shape == target.shape == x_start.shape # [bs, njoints, nfeats, nframes] target = x_start target_rel_base_pts_to_jts = target['rel_base_pts_to_rhand_joints'] target_dist_base_pts_to_jts = target['dist_base_pts_to_rhand_joints'] dec_rel = model_output['dec_rel'] dec_dist = model_output['dec_dist'] # e_model_output, dec_e_along_normalss, dec_e_vt_normals # dec_e_along_normals = e_model_output['dec_e_along_normals'] dec_e_vt_normals = e_model_output['dec_e_vt_normals'] # print(f'target_rel_base_pts_to_jts: {target_rel_base_pts_to_jts.size()}, target_dist_base_pts_to_jts: {target_dist_base_pts_to_jts.size()}, dec_rel: {dec_rel.size()}, dec_dist: {dec_dist.size()}') # terms['rot_mse'] = torch.sum( # (target_rel_base_pts_to_jts - dec_rel) ** 2, dim=-1 # ).mean() + torch.mean( # (target_dist_base_pts_to_jts - dec_dist) ** 2, dim=-1 # ) # rel pred; dist pred; # terms['rel_pred'] = torch.sum( (target_rel_base_pts_to_jts - dec_rel) ** 2, dim=-1 ).mean() terms['dist_pred'] = ((target_dist_base_pts_to_jts - dec_dist) ** 2).mean() rel_pred_loss = torch.sum( (target_rel_base_pts_to_jts - dec_rel) ** 2, dim=-1 ).mean(dim=-1).mean(dim=-1).mean(dim=-1) dist_pred_loss = ((target_dist_base_pts_to_jts - dec_dist) ** 2).mean(dim=-1).mean(dim=-1).mean(dim=-1) dec_e_along_normals_loss = ((dec_e_along_normals - x_start['e_disp_rel_to_base_along_normals']) ** 2).mean(dim=-1).mean(dim=-1).mean(dim=-1) dec_e_vt_normals_loss = ((dec_e_vt_normals - x_start['e_disp_rel_to_baes_vt_normals']) ** 2).mean(dim=-1).mean(dim=-1).mean(dim=-1) terms['rel_pred_loss'] = rel_pred_loss terms['dist_pred_loss'] = dist_pred_loss terms['dec_e_along_normals_loss'] = dec_e_along_normals_loss terms['dec_e_vt_normals_loss'] = dec_e_vt_normals_loss terms['rot_mse'] = (torch.sum( (target_rel_base_pts_to_jts - dec_rel) ** 2, dim=-1 ) + (target_dist_base_pts_to_jts - dec_dist) ** 2).mean(dim=-1).mean(dim=-1).mean(dim=-1) # x_start['e_disp_rel_to_base_along_normals'], x_start['e_disp_rel_to_baes_vt_normals'], terms['rot_mse'] = terms['rot_mse'] + ((dec_e_along_normals - x_start['e_disp_rel_to_base_along_normals']) ** 2).mean(dim=-1).mean(dim=-1).mean(dim=-1) + ((dec_e_vt_normals - x_start['e_disp_rel_to_baes_vt_normals']) ** 2).mean(dim=-1).mean(dim=-1).mean(dim=-1) # rel to joints? # terms["rot_mse"] = self.masked_l2(target, model_output, mask) # mean_flat(rot_mse) # sv_out_in = { # 'model_output': model_output.detach().cpu().numpy(), # 'target': target.detach().cpu().numpy(), # 't': t.detach().cpu().numpy(), # } import os import datetime cur_time_stamp = datetime.datetime.now().timestamp() cur_time_stamp = str(cur_time_stamp) sv_dir_rt = "/data1/sim/motion-diffusion-model" # sv_out_fn = os.path.join(sv_dir_rt, f"out_{cur_time_stamp}.npy") # np.save(sv_out_fn,sv_out_in ) # print(f"Samples saved to {sv_out_fn}") target_xyz, model_output_xyz = None, None if self.lambda_rcxyz > 0. 
and dataset.dataname not in ['motion_ours']: print(f"Calculating lambda_rcxyz!!!") target_xyz = get_xyz(target) # [bs, nvertices(vertices)/njoints(smpl), 3, nframes] model_output_xyz = get_xyz(model_output) # [bs, nvertices, 3, nframes] terms["rcxyz_mse"] = self.masked_l2(target_xyz, model_output_xyz, mask) # mean_flat((target_xyz - model_output_xyz) ** 2) if self.lambda_vel_rcxyz > 0.: if self.data_rep == 'rot6d' and dataset.dataname in ['humanact12', 'uestc']: target_xyz = get_xyz(target) if target_xyz is None else target_xyz model_output_xyz = get_xyz(model_output) if model_output_xyz is None else model_output_xyz target_xyz_vel = (target_xyz[:, :, :, 1:] - target_xyz[:, :, :, :-1]) model_output_xyz_vel = (model_output_xyz[:, :, :, 1:] - model_output_xyz[:, :, :, :-1]) terms["vel_xyz_mse"] = self.masked_l2(target_xyz_vel, model_output_xyz_vel, mask[:, :, :, 1:]) if self.lambda_fc > 0.: ## lambda fc ## torch.autograd.set_detect_anomaly(True) if self.data_rep == 'rot6d' and dataset.dataname in ['humanact12', 'uestc']: target_xyz = get_xyz(target) if target_xyz is None else target_xyz model_output_xyz = get_xyz(model_output) if model_output_xyz is None else model_output_xyz # 'L_Ankle', # 7, 'R_Ankle', # 8 , 'L_Foot', # 10, 'R_Foot', # 11 l_ankle_idx, r_ankle_idx, l_foot_idx, r_foot_idx = 7, 8, 10, 11 relevant_joints = [l_ankle_idx, l_foot_idx, r_ankle_idx, r_foot_idx] gt_joint_xyz = target_xyz[:, relevant_joints, :, :] # [BatchSize, 4, 3, Frames] gt_joint_vel = torch.linalg.norm(gt_joint_xyz[:, :, :, 1:] - gt_joint_xyz[:, :, :, :-1], axis=2) # [BatchSize, 4, Frames] fc_mask = torch.unsqueeze((gt_joint_vel <= 0.01), dim=2).repeat(1, 1, 3, 1) pred_joint_xyz = model_output_xyz[:, relevant_joints, :, :] # [BatchSize, 4, 3, Frames] pred_vel = pred_joint_xyz[:, :, :, 1:] - pred_joint_xyz[:, :, :, :-1] pred_vel[~fc_mask] = 0 terms["fc"] = self.masked_l2(pred_vel, torch.zeros(pred_vel.shape, device=pred_vel.device), mask[:, :, :, 1:]) self.lambda_vel = 0. if self.lambda_vel > 0.: target_vel = (target[..., 1:] - target[..., :-1]) model_output_vel = (model_output[..., 1:] - model_output[..., :-1]) terms["vel_mse"] = self.masked_l2(target_vel[:, :-1, :, :], # Remove last joint, is the root location! model_output_vel[:, :-1, :, :], mask[:, :, :, 1:]) # mean_flat((target_vel - model_output_vel) ** 2) terms["loss"] = terms["rot_mse"] + terms.get('vb', 0.) +\ (self.lambda_vel * terms.get('vel_mse', 0.)) +\ (self.lambda_rcxyz * terms.get('rcxyz_mse', 0.)) + \ (self.lambda_fc * terms.get('fc', 0.)) # else: # raise NotImplementedError(self.loss_type) return terms ## training losses def predict_sample_single_step(self, model, x_start, t, model_kwargs=None, noise=None, dataset=None): """ # s Compute training losses for a single timestep. :param model: the model to evaluate loss on. :param x_start: the [N x C x ...] tensor of inputs. :param t: a batch of timestep indices. :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :param noise: if specified, the specific Gaussian noise to try to remove. :return: a dict with the key "loss" containing a tensor of shape [N]. Some mean or variance settings may also have other keys. 
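        Unlike training_losses, this also returns the model output, the target
        and the sampled timesteps (optionally denormalized with
        model_kwargs['y']['avg_joints'] / ['std_joints']), and dumps the target
        and timesteps to an .npy file for offline inspection.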
""" # enc = model.model._modules['module'] # enc = model.model mask = model_kwargs['y']['mask'] ## rot2xyz; # # ### avg_joints, std_joints ### # if 'avg_joints' in model_kwargs['y']: avg_joints = model_kwargs['y']['avg_joints'].unsqueeze(-1) std_joints = model_kwargs['y']['std_joints'].unsqueeze(-1).unsqueeze(-1) else: avg_joints = None std_joints = None # ### avg_joints, std_joints ### # # get_xyz = lambda sample: enc.rot2xyz(sample, mask=None, pose_rep=enc.pose_rep, translation=enc.translation, # glob=enc.glob, ## rot2xyz; ## # # jointstype='vertices', # 3.4 iter/sec # USED ALSO IN MotionCLIP # jointstype='smpl', # 3.4 iter/sec # vertstrans=False) if model_kwargs is None: ## predict single steps --> model_kwargs = {} if noise is None: noise = th.randn_like(x_start) x_t = self.q_sample(x_start, t, noise=noise) # randn_like for x_start, t, x_t --> get x_t from x_start # # how we control the tiem stamp t? terms = {} if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: terms["loss"] = self._vb_terms_bpd( ## vb terms bpd # model=model, x_start=x_start, x_t=x_t, t=t, clip_denoised=False, model_kwargs=model_kwargs, )["output"] if self.loss_type == LossType.RESCALED_KL: terms["loss"] *= self.num_timesteps elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: model_output = model(x_t, self._scale_timesteps(t), **model_kwargs) # model_output ---> model x_t # if self.model_var_type in [ ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE, ]: # s B, C = x_t.shape[:2] assert model_output.shape == (B, C * 2, *x_t.shape[2:]) model_output, model_var_values = th.split(model_output, C, dim=1) # Learn the variance using the variational bound, but don't let # it affect our mean prediction. frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) terms["vb"] = self._vb_terms_bpd( model=lambda *args, r=frozen_out: r, x_start=x_start, x_t=x_t, t=t, clip_denoised=False, )["output"] if self.loss_type == LossType.RESCALED_MSE: # Divide by 1000 for equivalence with initial implementation. # Without a factor of 1/1000, the VB term hurts the MSE term. terms["vb"] *= self.num_timesteps / 1000.0 target = { # q posterior mean variance # ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( x_start=x_start, x_t=x_t, t=t )[0], ModelMeanType.START_X: x_start, ModelMeanType.EPSILON: noise, }[self.model_mean_type] # model mean type --> mean type # assert model_output.shape == target.shape == x_start.shape # [bs, njoints, nfeats, nframes] terms["rot_mse"] = self.masked_l2(target, model_output, mask) # mean_flat(rot_mse) # model_output, target, t # ### avg_joints, std_joints ### # if avg_joints is not None: print(f"model_output: {model_output.size()}, target: {target.size()}, std_joints: {std_joints.size()}, avg_joints: {avg_joints.size()}") print(f"Denormalizing joints...") model_output = (model_output * std_joints) + avg_joints target = (target * std_joints) + avg_joints sv_out_in = { # 'model_output': model_output.detach().cpu().numpy(), 'target': target.detach().cpu().numpy(), 't': t.detach().cpu().numpy(), } import os import datetime cur_time_stamp = datetime.datetime.now().timestamp() cur_time_stamp = str(cur_time_stamp) sv_dir_rt = "/data1/sim/motion-diffusion-model" sv_out_fn = os.path.join(sv_dir_rt, f"out_{cur_time_stamp}.npy") np.save(sv_out_fn,sv_out_in ) print(f"Samples saved to {sv_out_fn}") target_xyz, model_output_xyz = None, None self.lambda_rcxyz = 0. 
if self.lambda_rcxyz > 0.: target_xyz = get_xyz(target) # [bs, nvertices(vertices)/njoints(smpl), 3, nframes] model_output_xyz = get_xyz(model_output) # [bs, nvertices, 3, nframes] terms["rcxyz_mse"] = self.masked_l2(target_xyz, model_output_xyz, mask) # mean_flat((target_xyz - model_output_xyz) ** 2) if self.lambda_vel_rcxyz > 0.: if self.data_rep == 'rot6d' and dataset.dataname in ['humanact12', 'uestc']: target_xyz = get_xyz(target) if target_xyz is None else target_xyz model_output_xyz = get_xyz(model_output) if model_output_xyz is None else model_output_xyz target_xyz_vel = (target_xyz[:, :, :, 1:] - target_xyz[:, :, :, :-1]) model_output_xyz_vel = (model_output_xyz[:, :, :, 1:] - model_output_xyz[:, :, :, :-1]) terms["vel_xyz_mse"] = self.masked_l2(target_xyz_vel, model_output_xyz_vel, mask[:, :, :, 1:]) if self.lambda_fc > 0.: ## lambda fc ## torch.autograd.set_detect_anomaly(True) if self.data_rep == 'rot6d' and dataset.dataname in ['humanact12', 'uestc']: target_xyz = get_xyz(target) if target_xyz is None else target_xyz model_output_xyz = get_xyz(model_output) if model_output_xyz is None else model_output_xyz # 'L_Ankle', # 7, 'R_Ankle', # 8 , 'L_Foot', # 10, 'R_Foot', # 11 l_ankle_idx, r_ankle_idx, l_foot_idx, r_foot_idx = 7, 8, 10, 11 relevant_joints = [l_ankle_idx, l_foot_idx, r_ankle_idx, r_foot_idx] gt_joint_xyz = target_xyz[:, relevant_joints, :, :] # [BatchSize, 4, 3, Frames] gt_joint_vel = torch.linalg.norm(gt_joint_xyz[:, :, :, 1:] - gt_joint_xyz[:, :, :, :-1], axis=2) # [BatchSize, 4, Frames] fc_mask = torch.unsqueeze((gt_joint_vel <= 0.01), dim=2).repeat(1, 1, 3, 1) pred_joint_xyz = model_output_xyz[:, relevant_joints, :, :] # [BatchSize, 4, 3, Frames] pred_vel = pred_joint_xyz[:, :, :, 1:] - pred_joint_xyz[:, :, :, :-1] pred_vel[~fc_mask] = 0 terms["fc"] = self.masked_l2(pred_vel, torch.zeros(pred_vel.shape, device=pred_vel.device), mask[:, :, :, 1:]) if self.lambda_vel > 0.: target_vel = (target[..., 1:] - target[..., :-1]) model_output_vel = (model_output[..., 1:] - model_output[..., :-1]) terms["vel_mse"] = self.masked_l2(target_vel[:, :-1, :, :], # Remove last joint, is the root location! model_output_vel[:, :-1, :, :], mask[:, :, :, 1:]) # mean_flat((target_vel - model_output_vel) ** 2) terms["loss"] = terms["rot_mse"] + terms.get('vb', 0.) 
+\
                            (self.lambda_vel * terms.get('vel_mse', 0.)) +\
                            (self.lambda_rcxyz * terms.get('rcxyz_mse', 0.)) + \
                            (self.lambda_fc * terms.get('fc', 0.))
        else:
            raise NotImplementedError(self.loss_type)

        return terms, model_output, target, t

    def fc_loss_rot_repr(self, gt_xyz, pred_xyz, mask):
        def to_np_cpu(x):
            return x.detach().cpu().numpy()
        """
        gt_xyz / pred_xyz: SMPL batch tensors of shape [BatchSize, 24, 3, Frames]
        """
        # 'L_Ankle', # 7, 'R_Ankle', # 8, 'L_Foot', # 10, 'R_Foot', # 11
        l_ankle_idx, r_ankle_idx = 7, 8
        l_foot_idx, r_foot_idx = 10, 11
        """ Contact calculated by 'Kfir method' (commented code below) """
        # contact_signal = torch.zeros((pose_xyz.shape[0], pose_xyz.shape[3], 2), device=pose_xyz.device)  # [BatchSize, Frames, 2]
        # left_xyz = 0.5 * (pose_xyz[:, l_ankle_idx, :, :] + pose_xyz[:, l_foot_idx, :, :])  # [BatchSize, 3, Frames]
        # right_xyz = 0.5 * (pose_xyz[:, r_ankle_idx, :, :] + pose_xyz[:, r_foot_idx, :, :])
        # left_z, right_z = left_xyz[:, 2, :], right_xyz[:, 2, :]  # [BatchSize, Frames]
        # left_velocity = torch.linalg.norm(left_xyz[:, :, 2:] - left_xyz[:, :, :-2], axis=1)  # [BatchSize, Frames]
        # right_velocity = torch.linalg.norm(right_xyz[:, :, 2:] - right_xyz[:, :, :-2], axis=1)  # (fixed: was left_xyz)
        # left_z_mask = left_z <= torch.mean(torch.sort(left_z)[0][:, :left_z.shape[1] // 5], axis=-1)
        # left_z_mask = torch.stack([left_z_mask, left_z_mask], dim=-1)  # [BatchSize, Frames, 2]
        # left_z_mask[:, :, 1] = False  # Blank right side
        # contact_signal[left_z_mask] = 0.4
        # right_z_mask = right_z <= torch.mean(torch.sort(right_z)[0][:, :right_z.shape[1] // 5], axis=-1)
        # right_z_mask = torch.stack([right_z_mask, right_z_mask], dim=-1)  # [BatchSize, Frames, 2]
        # right_z_mask[:, :, 0] = False  # Blank left side
        # contact_signal[right_z_mask] = 0.4
        # contact_signal[left_z <= (torch.mean(torch.sort(left_z)[:left_z.shape[0] // 5]) + 20), 0] = 1
        # contact_signal[right_z <= (torch.mean(torch.sort(right_z)[:right_z.shape[0] // 5]) + 20), 1] = 1
        # plt.plot(to_np_cpu(left_z[0]), label='left_z')
        # plt.plot(to_np_cpu(left_velocity[0]), label='left_velocity')
        # plt.plot(to_np_cpu(contact_signal[0, :, 0]), label='left_fc')
        # plt.grid()
        # plt.legend()
        # plt.show()
        # plt.plot(to_np_cpu(right_z[0]), label='right_z')
        # plt.plot(to_np_cpu(right_velocity[0]), label='right_velocity')
        # plt.plot(to_np_cpu(contact_signal[0, :, 1]), label='right_fc')
        # plt.grid()
        # plt.legend()
        # plt.show()
        gt_joint_xyz = gt_xyz[:, [l_ankle_idx, l_foot_idx, r_ankle_idx, r_foot_idx], :, :]  # [BatchSize, 4, 3, Frames]
        gt_joint_vel = torch.linalg.norm(gt_joint_xyz[:, :, :, 1:] - gt_joint_xyz[:, :, :, :-1], axis=2)  # [BatchSize, 4, Frames]
        fc_mask = (gt_joint_vel <= 0.01)
        pred_joint_xyz = pred_xyz[:, [l_ankle_idx, l_foot_idx, r_ankle_idx, r_foot_idx], :, :]  # [BatchSize, 4, 3, Frames]
        pred_joint_vel = torch.linalg.norm(pred_joint_xyz[:, :, :, 1:] - pred_joint_xyz[:, :, :, :-1], axis=2)  # [BatchSize, 4, Frames]
        pred_joint_vel[~fc_mask] = 0  # Blank non-contact velocity frames. [BS, 4, Frames]
        pred_joint_vel = torch.unsqueeze(pred_joint_vel, dim=2)
        """ DEBUG CODE """
        # print(f'mask: {mask.shape}')
        # print(f'pred_joint_vel: {pred_joint_vel.shape}')
        # plt.title(f'Joint: {joint_idx}')
        # plt.plot(to_np_cpu(gt_joint_vel[0]), label='velocity')
        # plt.plot(to_np_cpu(fc_mask[0]), label='fc')
        # plt.grid()
        # plt.legend()
        # plt.show()
        return self.masked_l2(pred_joint_vel,
                              torch.zeros(pred_joint_vel.shape, device=pred_joint_vel.device),
                              mask[:, :, :, 1:])
    # TODO - NOT USED YET, JUST COMMITTING TO NOT DELETE THIS AND KEEP INITIAL IMPLEMENTATION, NOT DONE!
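    # A hypothetical call site for fc_loss_rot_repr once it is enabled; the
    # gt/pred xyz tensors would come from get_xyz as in training_losses
    # (sketch only, not wired up anywhere yet):
    #
    #   terms["fc"] = self.fc_loss_rot_repr(target_xyz, model_output_xyz, mask)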
def foot_contact_loss_humanml3d(self, target, model_output): # root_rot_velocity (B, seq_len, 1) # root_linear_velocity (B, seq_len, 2) # root_y (B, seq_len, 1) # ric_data (B, seq_len, (joint_num - 1)*3) , XYZ # rot_data (B, seq_len, (joint_num - 1)*6) , 6D # local_velocity (B, seq_len, joint_num*3) , XYZ # foot contact (B, seq_len, 4) , target_fc = target[:, -4:, :, :] root_rot_velocity = target[:, :1, :, :] root_linear_velocity = target[:, 1:3, :, :] root_y = target[:, 3:4, :, :] ric_data = target[:, 4:67, :, :] # 4+(3*21)=67 rot_data = target[:, 67:193, :, :] # 67+(6*21)=193 local_velocity = target[:, 193:259, :, :] # 193+(3*22)=259 contact = target[:, 259:, :, :] # 193+(3*22)=259 contact_mask_gt = contact > 0.5 # contact mask order for indexes are fid_l [7, 10], fid_r [8, 11] vel_lf_7 = local_velocity[:, 7 * 3:8 * 3, :, :] vel_rf_8 = local_velocity[:, 8 * 3:9 * 3, :, :] vel_lf_10 = local_velocity[:, 10 * 3:11 * 3, :, :] vel_rf_11 = local_velocity[:, 11 * 3:12 * 3, :, :] calc_vel_lf_7 = ric_data[:, 6 * 3:7 * 3, :, 1:] - ric_data[:, 6 * 3:7 * 3, :, :-1] calc_vel_rf_8 = ric_data[:, 7 * 3:8 * 3, :, 1:] - ric_data[:, 7 * 3:8 * 3, :, :-1] calc_vel_lf_10 = ric_data[:, 9 * 3:10 * 3, :, 1:] - ric_data[:, 9 * 3:10 * 3, :, :-1] calc_vel_rf_11 = ric_data[:, 10 * 3:11 * 3, :, 1:] - ric_data[:, 10 * 3:11 * 3, :, :-1] # vel_foots = torch.stack([vel_lf_7, vel_lf_10, vel_rf_8, vel_rf_11], dim=1) for chosen_vel_foot_calc, chosen_vel_foot, joint_idx, contact_mask_idx in zip( [calc_vel_lf_7, calc_vel_rf_8, calc_vel_lf_10, calc_vel_rf_11], [vel_lf_7, vel_lf_10, vel_rf_8, vel_rf_11], [7, 10, 8, 11], [0, 1, 2, 3]): tmp_mask_gt = contact_mask_gt[:, contact_mask_idx, :, :].cpu().detach().numpy().reshape(-1).astype(int) chosen_vel_norm = np.linalg.norm(chosen_vel_foot.cpu().detach().numpy().reshape((3, -1)), axis=0) chosen_vel_calc_norm = np.linalg.norm(chosen_vel_foot_calc.cpu().detach().numpy().reshape((3, -1)), axis=0) print(tmp_mask_gt.shape) print(chosen_vel_foot.shape) print(chosen_vel_calc_norm.shape) import matplotlib.pyplot as plt plt.plot(tmp_mask_gt, label='FC mask') plt.plot(chosen_vel_norm, label='Vel. XYZ norm (from vector)') plt.plot(chosen_vel_calc_norm, label='Vel. XYZ norm (calculated diff XYZ)') plt.title(f'FC idx {contact_mask_idx}, Joint Index {joint_idx}') plt.legend() plt.show() # print(vel_foots.shape) return 0 # TODO - NOT USED YET, JUST COMMITING TO NOT DELETE THIS AND KEEP INITIAL IMPLEMENTATION, NOT DONE! 
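    # HumanML3D feature layout assumed by the two debug helpers above and
    # below (263 dims in total): root_rot_velocity (1) + root_linear_velocity (2)
    # + root_y (1) + ric_data (21 * 3) + rot_data (21 * 6) + local_velocity (22 * 3)
    # + foot_contact (4).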
def velocity_consistency_loss_humanml3d(self, target, model_output): # root_rot_velocity (B, seq_len, 1) # root_linear_velocity (B, seq_len, 2) # root_y (B, seq_len, 1) # ric_data (B, seq_len, (joint_num - 1)*3) , XYZ # rot_data (B, seq_len, (joint_num - 1)*6) , 6D # local_velocity (B, seq_len, joint_num*3) , XYZ # foot contact (B, seq_len, 4) , target_fc = target[:, -4:, :, :] root_rot_velocity = target[:, :1, :, :] root_linear_velocity = target[:, 1:3, :, :] root_y = target[:, 3:4, :, :] ric_data = target[:, 4:67, :, :] # 4+(3*21)=67 rot_data = target[:, 67:193, :, :] # 67+(6*21)=193 local_velocity = target[:, 193:259, :, :] # 193+(3*22)=259 contact = target[:, 259:, :, :] # 193+(3*22)=259 calc_vel_from_xyz = ric_data[:, :, :, 1:] - ric_data[:, :, :, :-1] velocity_from_vector = local_velocity[:, 3:, :, 1:] # Slicing out root r_rot_quat, r_pos = motion_process.recover_root_rot_pos(target.permute(0, 2, 3, 1).type(th.FloatTensor)) print(f'r_rot_quat: {r_rot_quat.shape}') print(f'calc_vel_from_xyz: {calc_vel_from_xyz.shape}') calc_vel_from_xyz = calc_vel_from_xyz.permute(0, 2, 3, 1) calc_vel_from_xyz = calc_vel_from_xyz.reshape((1, 1, -1, 21, 3)).type(th.FloatTensor) r_rot_quat_adapted = r_rot_quat[..., :-1, None, :].repeat((1,1,1,21,1)).to(calc_vel_from_xyz.device) print(f'calc_vel_from_xyz: {calc_vel_from_xyz.shape} , {calc_vel_from_xyz.device}') print(f'r_rot_quat_adapted: {r_rot_quat_adapted.shape}, {r_rot_quat_adapted.device}') calc_vel_from_xyz = motion_process.qrot(r_rot_quat_adapted, calc_vel_from_xyz) calc_vel_from_xyz = calc_vel_from_xyz.reshape((1, 1, -1, 21 * 3)) calc_vel_from_xyz = calc_vel_from_xyz.permute(0, 3, 1, 2) print(f'calc_vel_from_xyz: {calc_vel_from_xyz.shape} , {calc_vel_from_xyz.device}') import matplotlib.pyplot as plt for i in range(21): plt.plot(np.linalg.norm(calc_vel_from_xyz[:,i*3:(i+1)*3,:,:].cpu().detach().numpy().reshape((3, -1)), axis=0), label='Calc Vel') plt.plot(np.linalg.norm(velocity_from_vector[:,i*3:(i+1)*3,:,:].cpu().detach().numpy().reshape((3, -1)), axis=0), label='Vector Vel') plt.title(f'Joint idx: {i}') plt.legend() plt.show() print(calc_vel_from_xyz.shape) print(velocity_from_vector.shape) diff = calc_vel_from_xyz-velocity_from_vector print(np.linalg.norm(diff.cpu().detach().numpy().reshape((63, -1)), axis=0)) return 0 def _prior_bpd(self, x_start): """ Get the prior KL term for the variational lower-bound, measured in bits-per-dim. This term can't be optimized, as it only depends on the encoder. :param x_start: the [N x C x ...] tensor of inputs. :return: a batch of [N] KL values (in bits), one per batch element. """ batch_size = x_start.shape[0] t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) kl_prior = normal_kl( mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 ) return mean_flat(kl_prior) / np.log(2.0) def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): """ Compute the entire variational lower-bound, measured in bits-per-dim, as well as other related quantities. :param model: the model to evaluate loss on. :param x_start: the [N x C x ...] tensor of inputs. :param clip_denoised: if True, clip denoised samples. :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :return: a dict containing the following keys: - total_bpd: the total variational lower-bound, per batch element. - prior_bpd: the prior term in the lower-bound. 
- vb: an [N x T] tensor of terms in the lower-bound. - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. - mse: an [N x T] tensor of epsilon MSEs for each timestep. """ device = x_start.device batch_size = x_start.shape[0] vb = [] xstart_mse = [] mse = [] for t in list(range(self.num_timesteps))[::-1]: t_batch = th.tensor([t] * batch_size, device=device) noise = th.randn_like(x_start) x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) # Calculate VLB term at the current timestep with th.no_grad(): out = self._vb_terms_bpd( model, x_start=x_start, x_t=x_t, t=t_batch, clip_denoised=clip_denoised, model_kwargs=model_kwargs, ) vb.append(out["output"]) xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) mse.append(mean_flat((eps - noise) ** 2)) vb = th.stack(vb, dim=1) xstart_mse = th.stack(xstart_mse, dim=1) mse = th.stack(mse, dim=1) prior_bpd = self._prior_bpd(x_start) total_bpd = vb.sum(dim=1) + prior_bpd return { "total_bpd": total_bpd, "prior_bpd": prior_bpd, "vb": vb, "xstart_mse": xstart_mse, "mse": mse, } class GaussianDiffusionV3: """ Utilities for training and sampling diffusion models. Ported directly from here, and then adapted over time to further experimentation. https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 :param betas: a 1-D numpy array of betas for each diffusion timestep, ## pass floating point timesteps ## starting at T and going to 1. :param model_mean_type: a ModelMeanType determining what the model outputs. :param model_var_type: a ModelVarType determining how variance is output. :param loss_type: a LossType determining the loss function to use. :param rescale_timesteps: if True, pass floating point timesteps into the model so that they are always scaled like in the original paper (0 to 1000). """ def __init__( self, *, betas, model_mean_type, model_var_type, loss_type, rescale_timesteps=False, lambda_rcxyz=0., lambda_vel=0., lambda_pose=1., lambda_orient=1., lambda_loc=1., data_rep='rot6d', lambda_root_vel=0., lambda_vel_rcxyz=0., lambda_fc=0., denoising_stra="rep", inter_optim=False, ): self.model_mean_type = model_mean_type self.model_var_type = model_var_type ## model var type ## self.loss_type = loss_type self.rescale_timesteps = rescale_timesteps self.data_rep = data_rep if data_rep != 'rot_vel' and lambda_pose != 1.: raise ValueError('lambda_pose is relevant only when training on velocities!') self.lambda_pose = lambda_pose self.lambda_orient = lambda_orient self.lambda_loc = lambda_loc self.lambda_rcxyz = lambda_rcxyz self.lambda_vel = lambda_vel self.lambda_root_vel = lambda_root_vel self.lambda_vel_rcxyz = lambda_vel_rcxyz self.lambda_fc = lambda_fc ### === denoising_stra for the denoising process === ### self.denoising_stra = denoising_stra self.inter_optim = inter_optim if self.lambda_rcxyz > 0. or self.lambda_vel > 0. or self.lambda_root_vel > 0. or \ self.lambda_vel_rcxyz > 0. or self.lambda_fc > 0.: assert self.loss_type == LossType.MSE, 'Geometric losses are supported by MSE loss type only!' # Use float64 for accuracy. 
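        # The schedule constants below implement the standard DDPM identities:
        #   q(x_t | x_0)          = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I)
        #   q(x_{t-1} | x_t, x_0) = N(coef1 * x_0 + coef2 * x_t, posterior_variance_t * I)
        # with alpha_bar_t = prod_{s <= t} (1 - beta_s).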
betas = np.array(betas, dtype=np.float64) self.betas = betas assert len(betas.shape) == 1, "betas must be 1-D" ## betas assert (betas > 0).all() and (betas <= 1).all() self.num_timesteps = int(betas.shape[0]) alphas = 1.0 - betas self.alphas_cumprod = np.cumprod(alphas, axis=0) self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) ## alphas cumprod self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) # calculations for diffusion q(x_t | x_{t-1}) and others self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) # calculations for posterior q(x_{t-1} | x_t, x_0) self.posterior_variance = ( betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) ) # log calculation clipped because the posterior variance is 0 at the # beginning of the diffusion chain. self.posterior_log_variance_clipped = np.log( np.append(self.posterior_variance[1], self.posterior_variance[1:]) ) self.posterior_mean_coef1 = ( betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) ) self.posterior_mean_coef2 = ( (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) ) self.l2_loss = lambda a, b: (a - b) ** 2 # th.nn.MSELoss(reduction='none') # must be None for handling mask later on. ''' Load statistics ''' avg_joints_motion_ours_fn = f"/home/xueyi/sim/motion-diffusion-model/avg_joints_motion_ours_nb_{700}_nth_{0.005}.npy" std_joints_motion_ours_fn = f"/home/xueyi/sim/motion-diffusion-model/std_joints_motion_ours_nb_{700}_nth_{0.005}.npy" avg_joints_motion_dists_ours_fn = f"/home/xueyi/sim/motion-diffusion-model/avg_joints_dist_motion_ours_nb_{700}_nth_{0.005}.npy" std_joints_motion_dists_ours_fn = f"/home/xueyi/sim/motion-diffusion-model/std_joints_dist_motion_ours_nb_{700}_nth_{0.005}.npy" avg_joints_rel = np.load(avg_joints_motion_ours_fn, allow_pickle=True) std_joints_rel = np.load(std_joints_motion_ours_fn, allow_pickle=True) avg_joints_dists = np.load(avg_joints_motion_dists_ours_fn, allow_pickle=True) std_joints_dists = np.load(std_joints_motion_dists_ours_fn, allow_pickle=True) ## self.avg_joints_rel, self.std_joints_rel ## self.avg_joints_dists, self.std_joints_dists self.avg_joints_rel = torch.from_numpy(avg_joints_rel).float() self.std_joints_rel = torch.from_numpy(std_joints_rel).float() self.avg_joints_dists = torch.from_numpy(avg_joints_dists).float() self.std_joints_dists = torch.from_numpy(std_joints_dists).float() ''' Load statistics ''' ''' Load avg, std statistics ''' # self.maxx_rel, minn_rel, maxx_dists, minn_dists # rel_dists_stats_fn = "/home/xueyi/sim/motion-diffusion-model/base_pts_rel_dists_stats.npy" rel_dists_stats = np.load(rel_dists_stats_fn, allow_pickle=True).item() maxx_rel = rel_dists_stats['maxx_rel'] minn_rel = rel_dists_stats['minn_rel'] maxx_dists = rel_dists_stats['maxx_dists'] minn_dists = rel_dists_stats['minn_dists'] self.maxx_rel = torch.from_numpy(maxx_rel).float() self.minn_rel = torch.from_numpy(minn_rel).float() self.maxx_dists = torch.from_numpy(maxx_dists).float() self.minn_dists = torch.from_numpy(minn_dists).float() ''' Load avg, std statistics ''' ''' Load avg-jts, std-jts ''' avg_jts_fn = 
"/home/xueyi/sim/motion-diffusion-model/avg_joints_motion_ours.npy" std_jts_fn = "/home/xueyi/sim/motion-diffusion-model/std_joints_motion_ours.npy" avg_jts = np.load(avg_jts_fn, allow_pickle=True) std_jts = np.load(std_jts_fn, allow_pickle=True) # self.avg_jts, self.std_jts # self.avg_jts = torch.from_numpy(avg_jts).float() self.std_jts = torch.from_numpy(std_jts).float() ''' Load avg-jts, std-jts ''' def masked_l2(self, a, b, mask): # assuming a.shape == b.shape == bs, J, Jdim, seqlen # assuming mask.shape == bs, 1, 1, seqlen loss = self.l2_loss(a, b) loss = sum_flat(loss * mask.float()) # gives \sigma_euclidean over unmasked elements n_entries = a.shape[1] * a.shape[2] non_zero_elements = sum_flat(mask) * n_entries ## non-zero-elements # print('mask', mask.shape) # print('non_zero_elements', non_zero_elements) # print('loss', loss) mse_loss_val = loss / non_zero_elements # print('mse_loss_val', mse_loss_val) return mse_loss_val def q_mean_variance(self, x_start, t): """ Get the distribution q(x_t | x_0). # q-mean-variance # :param x_start: the [N x C x ...] tensor of noiseless inputs. :param t: the number of diffusion steps (minus 1). Here, 0 means one step. :return: A tuple (mean, variance, log_variance), all of x_start's shape. """ mean = ( _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start ) variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) log_variance = _extract_into_tensor( self.log_one_minus_alphas_cumprod, t, x_start.shape ) return mean, variance, log_variance def q_sample(self, x_start, t, noise=None): """ Diffuse the dataset for a given number of diffusion steps. In other words, sample from q(x_t | x_0). ## q pos :param x_start: the initial dataset batch. :param t: the number of diffusion steps (minus 1). Here, 0 means one step. :param noise: if specified, the split-out normal noise. :return: A noisy version of x_start. """ if noise is None: noise = th.randn_like(x_start) assert noise.shape == x_start.shape return ( _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise ) ## q_sample, def q_posterior_mean_variance(self, x_start, x_t, t): ## q_posterior """ Compute the mean and variance of the diffusion posterior: q(x_{t-1} | x_t, x_0) """ assert x_start.shape == x_t.shape posterior_mean = ( # posterior mean and variance # _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t ) posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) posterior_log_variance_clipped = _extract_into_tensor( self.posterior_log_variance_clipped, t, x_t.shape ) assert ( posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == x_start.shape[0] ) return posterior_mean, posterior_variance, posterior_log_variance_clipped def p_mean_variance( ## get mean data ## self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None ): """ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of the initial x, x_0. :param model: the model, which takes a signal and a batch of timesteps as input. :param x: the [N x C x ...] tensor at time t. :param t: a 1-D Tensor of timesteps. :param clip_denoised: if True, clip the denoised signal into [-1, 1]. 
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample. Applies before
            clip_denoised.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :return: a dict with the following keys:
                 - 'mean': the model mean output.
                 - 'variance': the model variance output.
                 - 'log_variance': the log of 'variance'.
                 - 'pred_xstart': the prediction for x_0.
        """
        if model_kwargs is None:
            model_kwargs = {}

        ## === version 1 -> predict x_start in the relative domain === ##
        ## === x is formulated as the model inputs === ##
        ## === TODO: process x_start from the predicted values (relative positions,
        ##     signed distances) with base_pts and base_normals === ##
        B = x['rel_base_pts_to_rhand_joints'].shape[0]
        assert t.shape == (B,)
        # bsz x ws x nnj x nnb x 3 #
        input_rel_base_pts_to_rhand_joints = x['rel_base_pts_to_rhand_joints']
        input_dist_base_pts_to_rhand_joints = x['dist_base_pts_to_rhand_joints']
        # de-normalize the relative positions and add the base points back to
        # recover joints expressed per base point # bsz x ws x nnj x nnb x 3 #
        input_rhand_joints_based_on_base_pts = input_rel_base_pts_to_rhand_joints * x['per_frame_std_joints_rel'] + x['per_frame_avg_joints_rel']
        input_rhand_joints_based_on_base_pts = input_rhand_joints_based_on_base_pts + x['base_pts'].unsqueeze(1).unsqueeze(1)

        joints_scaling_factor = 5.
        # input_rhand_joints = x['rhand_joints']
        # normed_rhand_joints = (input_rhand_joints - self.avg_jts.unsqueeze(0).to(input_rhand_joints.device)) / self.std_jts.unsqueeze(0).to(input_rhand_joints.device)
        normed_rhand_joints = x['pert_rhand_joints']
        scaled_normed_rhand_joints = normed_rhand_joints - self.avg_exp_rhand_joints.unsqueeze(1)
        scaled_rhand_joints = scaled_normed_rhand_joints * joints_scaling_factor

        joint_seq_output, model_output, e_model_output = model(x, self._scale_timesteps(t))

        ### === model variance and log_variance === ###
        model_variance, model_log_variance = self.posterior_variance, self.posterior_log_variance_clipped
        # print('model_variance', model_variance)
        # print('model_log_variance', model_log_variance)
        # print('self.posterior_variance', self.posterior_variance)
        # print('self.posterior_log_variance_clipped', self.posterior_log_variance_clipped)
        # print('self.model_var_type', self.model_var_type)

        pred_rel_base_pts_to_rhand_joints = model_output['dec_rel']
        pred_rel_base_pts_to_rhand_joints_denormed = pred_rel_base_pts_to_rhand_joints * x['per_frame_std_joints_rel'] + x['per_frame_avg_joints_rel']
        # pred_rhand_joints_based_on_base_pts: bsz x ws x nnj x nnb x 3 #
        pred_rhand_joints_based_on_base_pts = pred_rel_base_pts_to_rhand_joints_denormed + x['base_pts'].unsqueeze(1).unsqueeze(1)
        pred_rhand_joints = torch.mean(pred_rhand_joints_based_on_base_pts, dim=-2)  # bsz x ws x nnj x 3 #
        # pred_rhand_joints_normed: bsz x ws x nnj x 3 #
        pred_rhand_joints_normed = (pred_rhand_joints - self.avg_jts.unsqueeze(0).to(pred_rhand_joints.device)) / self.std_jts.unsqueeze(0).to(pred_rhand_joints.device)
        # input_rhand_joints_normed = (x['rhand_joints'] - self.avg_jts.unsqueeze(0).to(pred_rhand_joints.device)) / self.std_jts.unsqueeze(0).to(pred_rhand_joints.device)
        input_rhand_joints_normed = x['pert_rhand_joints']

        model_variance = _extract_into_tensor(model_variance, t, x['pert_rhand_joints'].shape)
        model_log_variance = _extract_into_tensor(model_log_variance, t, x['pert_rhand_joints'].shape)

        # rhand_joints_normed_mean, _, _ = self.q_posterior_mean_variance(
        #     x_start=input_rhand_joints_normed, x_t=pred_rhand_joints_normed, t=t
        # )
        rhand_joints_normed_mean, _, _ = self.q_posterior_mean_variance(
            x_start=pred_rhand_joints_normed, x_t=input_rhand_joints_normed, t=t
        )  # bsz x ws x nnj x 3 #

        ## ==== just for debugging: every assignment to model_mean below is
        ##      overwritten by the q_posterior_mean_variance call that follows ==== ##
        # rhand_joints_mean = rhand_joints_normed_mean * self.std_jts.unsqueeze(0).to(pred_rhand_joints.device) + self.avg_jts.unsqueeze(0).to(pred_rhand_joints.device)
        rhand_joints_mean = pred_rhand_joints_normed * self.std_jts.unsqueeze(0).to(pred_rhand_joints.device) + self.avg_jts.unsqueeze(0).to(pred_rhand_joints.device)
        # model_mean: bsz x ws x nnj x nnb x 3 #
        model_mean = rhand_joints_mean.unsqueeze(-2) - x['base_pts'].unsqueeze(1).unsqueeze(1)
        model_mean = (model_mean - x['per_frame_avg_joints_rel']) / x['per_frame_std_joints_rel']
        model_mean = rhand_joints_normed_mean
        # model_mean = pred_rhand_joints_normed
        ## ==== just for debugging ==== ##

        #### Using normed joints ####
        model_mean, _, _ = self.q_posterior_mean_variance(
            x_start=joint_seq_output, x_t=normed_rhand_joints, t=t
        )
        #### Using normed joints ####
        #### Using scaled joints ####
        # model_mean, _, _ = self.q_posterior_mean_variance(
        #     x_start=joint_seq_output, x_t=scaled_rhand_joints, t=t
        # )
        #### Using scaled joints ####
        return {
            "mean": model_mean,
            "variance": model_variance,
            "log_variance": model_log_variance,
            "pred_xstart": pred_rel_base_pts_to_rhand_joints,
        }

    def _predict_xstart_from_eps(self, x_t, t, eps):
        assert x_t.shape == eps.shape
        return (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
        )

    def _predict_xstart_from_xprev(self, x_t, t, xprev):
        assert x_t.shape == xprev.shape
        return (  # (xprev - coef2 * x_t) / coef1
            _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
            - _extract_into_tensor(
                self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
            ) * x_t
        )

    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
        return (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - pred_xstart
        ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    def _scale_timesteps(self, t):
        if self.rescale_timesteps:
            return t.float() * (1000.0 / self.num_timesteps)
        return t

    def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute the mean for the previous step, given a function cond_fn that
        computes the gradient of a conditional log probability with respect to
        x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
        condition on y.

        This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
        """
        gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
        new_mean = (
            p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
        )
        return new_mean

    def condition_mean_with_grad(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute the mean for the previous step, given a function cond_fn that
        computes the gradient of a conditional log probability with respect to
        x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
        condition on y.

        This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
""" gradient = cond_fn(x, t, p_mean_var, **model_kwargs) new_mean = ( p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() ) return new_mean def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): """ Compute what the p_mean_variance output would have been, should the model's score function be conditioned by cond_fn. See condition_mean() for details on cond_fn. Unlike condition_mean(), this instead uses the conditioning strategy from Song et al (2020). """ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) eps = eps - (1 - alpha_bar).sqrt() * cond_fn( x, self._scale_timesteps(t), **model_kwargs ) out = p_mean_var.copy() out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) out["mean"], _, _ = self.q_posterior_mean_variance( x_start=out["pred_xstart"], x_t=x, t=t ) return out def condition_score_with_grad(self, cond_fn, p_mean_var, x, t, model_kwargs=None): """ Compute what the p_mean_variance output would have been, should the model's score function be conditioned by cond_fn. See condition_mean() for details on cond_fn. Unlike condition_mean(), this instead uses the conditioning strategy from Song et al (2020). """ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) eps = eps - (1 - alpha_bar).sqrt() * cond_fn( x, t, p_mean_var, **model_kwargs ) out = p_mean_var.copy() out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) out["mean"], _, _ = self.q_posterior_mean_variance( x_start=out["pred_xstart"], x_t=x, t=t ) return out def p_sample( self, model, x, t, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, const_noise=False, ): """ Sample x_{t-1} from the model at the given timestep. :param model: the model to sample from. :param x: the current tensor at x_{t-1}. :param t: the value of t, starting at 0 for the first diffusion step. :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. :param denoised_fn: if not None, a function which applies to the x_start prediction before it is used to sample. :param cond_fn: if not None, this is a gradient function that acts similarly to the model. :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :return: a dict containing the following keys: - 'sample': a random sample from the model. - 'pred_xstart': a prediction of x_0. 
""" out = self.p_mean_variance( model, x, t, # starting clip_denoised=clip_denoised, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) # bsz x ws x nnj x nnb x 3 # noise = th.randn_like(x['pert_rhand_joints']) # print('const_noise', const_noise) if const_noise: noise = noise[[0]].repeat(x['pert_rhand_joints'].shape[0], 1, 1, 1, 1) nonzero_mask = ( (t != 0).float().view(-1, *([1] * (len(x['pert_rhand_joints'].shape) - 1))) ) # no noise when t == 0 if cond_fn is not None: out["mean"] = self.condition_mean( cond_fn, out, x['pert_rhand_joints'], t, model_kwargs=model_kwargs ) # print('mean', out["mean"].shape, out["mean"]) # print('log_variance', out["log_variance"].shape, out["log_variance"]) # print('nonzero_mask', nonzero_mask.shape, nonzero_mask) # sample # why the out only remember relative sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise # self.avg_exp_rhand_joints #### using normed joints #### # # bsz x ws x nnj x 3 # # denormed # rhand_joints_mean = sample * self.std_jts.unsqueeze(0).to(sample.device) + self.avg_jts.unsqueeze(0).to(sample.device) #### using normed joints #### ### using scaled joints ### # joints_scaling_factor = 5. # rhand_joints_mean = sample / joints_scaling_factor # rhand_joints_mean = rhand_joints_mean + self.avg_exp_rhand_joints.unsqueeze(1) ### using scaled joints ### # model_mena: bsz x ws x nnj x nnb x 3 # sample = rhand_joints_mean.unsqueeze(-2) - x['base_pts'].unsqueeze(1).unsqueeze(1) sample = (sample - x['per_frame_avg_joints_rel']) / x['per_frame_std_joints_rel'] return {"sample": sample, "pred_xstart": out["pred_xstart"]} def p_sample_with_grad( self, model, x, t, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, ): """ Sample x_{t-1} from the model at the given timestep. :param model: the model to sample from. :param x: the current tensor at x_{t-1}. :param t: the value of t, starting at 0 for the first diffusion step. :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. :param denoised_fn: if not None, a function which applies to the x_start prediction before it is used to sample. :param cond_fn: if not None, this is a gradient function that acts similarly to the model. :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :return: a dict containing the following keys: - 'sample': a random sample from the model. - 'pred_xstart': a prediction of x_0. """ with th.enable_grad(): x = x.detach().requires_grad_() out = self.p_mean_variance( model, x, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) noise = th.randn_like(x) nonzero_mask = ( (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) ) # no noise when t == 0 if cond_fn is not None: out["mean"] = self.condition_mean_with_grad( cond_fn, out, x, t, model_kwargs=model_kwargs ) sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise return {"sample": sample, "pred_xstart": out["pred_xstart"].detach()} def p_sample_loop( self, model, shape, noise=None, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, progress=False, skip_timesteps=0, init_image=None, randomize_class=False, cond_fn_with_grad=False, dump_steps=None, const_noise=False, st_timestep=None, ): ## """ Generate samples from the model. :param model: the model module. :param shape: the shape of the samples, (N, C, H, W). :param noise: if specified, the noise from the encoder to sample. 
                      Should be of the same shape as `shape`.
        :param clip_denoised: if True, clip x_start predictions to [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample.
        :param cond_fn: if not None, this is a gradient function that acts
                        similarly to the model.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param device: if specified, the device to create the samples on.
                       If not specified, use a model parameter's device.
        :param progress: if True, show a tqdm progress bar.
        :param const_noise: if True, noise all samples with the same noise
            throughout sampling.
        :return: a non-differentiable batch of samples.
        """
        final = None
        # if dump_steps is not None:
        dump = []

        # iterate the progressive sampler; keep intermediate dumps and the final sample #
        for i, sample in enumerate(self.p_sample_loop_progressive(
            model,
            shape,
            noise=noise,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            cond_fn=cond_fn,
            model_kwargs=model_kwargs,
            device=device,
            progress=progress,
            skip_timesteps=skip_timesteps,
            init_image=init_image,
            randomize_class=randomize_class,
            cond_fn_with_grad=cond_fn_with_grad,
            const_noise=const_noise,  # the same noise across samples #
            st_timestep=st_timestep,
        )):
            if dump_steps is not None and i in dump_steps:
                dump.append(deepcopy(sample))
            final = sample
        if dump_steps is not None:
            return dump
        return final

    def p_sample_loop_progressive(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        skip_timesteps=0,
        init_image=None,
        randomize_class=False,
        cond_fn_with_grad=False,
        const_noise=False,
        st_timestep=None,
    ):
        """
        Generate samples from the model and yield intermediate samples from
        each timestep of diffusion.

        Arguments are the same as p_sample_loop().
        Returns a generator over dicts, where each dict is the return value of
        p_sample().
        """
        ### ==== conditional sampling from init_image here ==== ###
        ### ==== init_image should not be None ==== ###
        base_pts = init_image['base_pts']
        base_normals = init_image['base_normals']  # base normals #
        # rel_base_pts_to_rhand_joints = init_image['rel_base_pts_to_rhand_joints']
        # dist_base_pts_to_rhand_joints = init_image['dist_base_pts_to_rhand_joints']
        rhand_joints = init_image['rhand_joints']

        if 'sampled_base_pts_nearest_obj_pc' in init_image:
            ambient_init_image = {
                'sampled_base_pts_nearest_obj_pc': init_image['sampled_base_pts_nearest_obj_pc'],
                'sampled_base_pts_nearest_obj_vns': init_image['sampled_base_pts_nearest_obj_vns'],
            }
        init_image_avg_std_stats = {
            # 'rhand_joints': init_image['rhand_joints'],
            'per_frame_avg_joints_rel': init_image['per_frame_avg_joints_rel'],
            'per_frame_std_joints_rel': init_image['per_frame_std_joints_rel'],
            'per_frame_avg_joints_dists_rel': init_image['per_frame_avg_joints_dists_rel'],
            'per_frame_std_joints_dists_rel': init_image['per_frame_std_joints_dists_rel'],
        }

        if device is None:
            device = next(model.parameters()).device
        assert isinstance(shape, (tuple, list))
        if noise is not None:
            img = noise
        else:
            img = th.randn(*shape, device=device)
        # if skip_timesteps and init_image is None:
        #     rhand_joints = th.zeros_like(img)

        indices = list(range(self.num_timesteps - skip_timesteps))[::-1]
        if st_timestep is not None:
            indices = indices[-st_timestep:]
            print(f"st_timestep: {st_timestep}, indices: {indices}")

        ''' No rhand_joints perturbation here '''
        # if rhand_joints is not None:  # largest variance for sampling: [->t] add noise + [<-t] remove noise
        #     my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0]
        #     pert_rhand_joints = self.q_sample(rhand_joints, my_t, img)
        #     rel_base_pts_to_pert_rhand_joints = pert_rhand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1)
        #     # bsz x ws x nnjts x nnb #
        #     dist_base_pts_to_pert_rhand_joints = torch.sum(
        #         rel_base_pts_to_pert_rhand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1
        #     )
        #     ''' Relative positions and distances normalization, strategy 2 '''
        #     rel_base_pts_to_pert_rhand_joints = (rel_base_pts_to_pert_rhand_joints - self.avg_joints_rel.unsqueeze(-2)) / self.std_joints_rel.unsqueeze(-2)
        #     dist_base_pts_to_pert_rhand_joints = (dist_base_pts_to_pert_rhand_joints - self.avg_joints_dists.unsqueeze(-1)) / self.std_joints_dists.unsqueeze(-1)
        #     ''' Relative positions and distances normalization, strategy 2 '''
        ''' No rhand_joints perturbation here '''

        joints_scaling_factor = 5.
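        # Illustrative sketch (comments only, not executed): the two joint
        # parameterizations prepared in the "rep" branch below; shapes follow the
        # surrounding annotations (bsz x ws x nnj x 3).
        #   normed_joints = (joints - avg_jts) / std_jts                    # per-joint z-scoring
        #   scaled_joints = (joints - bbox_center) * joints_scaling_factor  # center-and-scale
        # The denoiser consumes one of these; the inverse transform recovers metric
        # joints before the base-point-relative features are rebuilt.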
        if self.denoising_stra == "rep":
            ''' Normalization Strategy 4 '''
            my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0]
            # flatten over (bsz * ws) to get per-joint statistics #
            exp_rhand_joints = rhand_joints.view(rhand_joints.size(0) * rhand_joints.size(1), rhand_joints.size(2), 3)
            self.avg_jts = torch.mean(exp_rhand_joints, dim=0, keepdim=True)
            self.std_jts = torch.std(exp_rhand_joints, dim=0, keepdim=True)
            normed_rhand_joints = (rhand_joints - self.avg_jts.unsqueeze(0).to(rhand_joints.device)) / self.std_jts.unsqueeze(0).to(rhand_joints.device)
            noise_rhand_joints = th.randn_like(normed_rhand_joints)
            pert_normed_rhand_joints = self.q_sample(normed_rhand_joints, my_t, noise=noise_rhand_joints)

            ### scale rhand joints ###
            # rhand_joints: bsz x ws x nnj x 3 #
            exp_rhand_joints = rhand_joints.view(rhand_joints.size(0), rhand_joints.size(1) * rhand_joints.size(2), 3)
            maxx_exp_rhand_joints, _ = torch.max(exp_rhand_joints, dim=1, keepdim=True)
            minn_exp_rhand_joints, _ = torch.min(exp_rhand_joints, dim=1, keepdim=True)
            avg_exp_rhand_joints = (maxx_exp_rhand_joints + minn_exp_rhand_joints) / 2.
            self.avg_exp_rhand_joints = avg_exp_rhand_joints
            rhand_joints = rhand_joints - avg_exp_rhand_joints.unsqueeze(1)
            scaled_rhand_joints = rhand_joints * joints_scaling_factor
            noise_scaled_rhand_joints = th.randn_like(scaled_rhand_joints)
            pert_scaled_rhand_joints = self.q_sample(scaled_rhand_joints, my_t, noise=noise_scaled_rhand_joints)

            # map the perturbed normed joints back to metric space #
            pert_rhand_joints = pert_normed_rhand_joints * self.std_jts.unsqueeze(0).to(rhand_joints.device) + self.avg_jts.unsqueeze(0).to(rhand_joints.device)

            # base_pts: bsz x nnb x 3 # base_normals: bsz x nnb x 3 #
            rel_base_pts_to_pert_rhand_joints = pert_rhand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1)  # bsz x ws x nnj x nnb x 3 #
            # signed distances of the relative positions along base normals #
            dist_rel_base_pts_to_pert_rhand_joints = torch.sum(
                rel_base_pts_to_pert_rhand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1
            )

            k_f = 1.
            # l2_rel_base_pts_to_pert_rhand_joints: bsz x nf x nnj x nnb #
            l2_rel_base_pts_to_pert_rhand_joints = torch.norm(rel_base_pts_to_pert_rhand_joints, dim=-1)
            ### att_forces ###
            att_forces = torch.exp(-k_f * l2_rel_base_pts_to_pert_rhand_joints)  # bsz x nf x nnj x nnb #
            att_forces = att_forces[:, :-1, :, :]
            # pert_rhand_joints: bsz x ws x nnj x 3 -> displacements: bsz x (ws - 1) x nnj x 3 #
            pert_rhand_joints_disp = pert_rhand_joints[:, 1:, :, :] - pert_rhand_joints[:, :-1, :, :]
            # signed displacement along base normals: bsz x (ws - 1) x nnj x nnb #
            signed_dist_base_pts_to_pert_rhand_joints_along_normal = torch.sum(
                base_normals.unsqueeze(1).unsqueeze(1) * pert_rhand_joints_disp.unsqueeze(-2), dim=-1
            )
            # rel_base_pts_to_pert_rhand_joints_vt_normal: bsz x (ws - 1) x nnj x nnb x 3 ->
            # the displacement components vertical to base normals #
            rel_base_pts_to_pert_rhand_joints_vt_normal = pert_rhand_joints_disp.unsqueeze(-2) - signed_dist_base_pts_to_pert_rhand_joints_along_normal.unsqueeze(-1) * base_normals.unsqueeze(1).unsqueeze(1)
            dist_base_pts_to_pert_rhand_joints_vt_normal = torch.sqrt(torch.sum(
                rel_base_pts_to_pert_rhand_joints_vt_normal ** 2, dim=-1
            ))
            k_a = 1.
            k_b = 1.
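            # Sketch of the displacement-energy terms computed next (a reading aid
            # mirroring the code below; shapes as annotated there):
            #   att_forces = exp(-k_f * ||rel||)                         # bsz x (nf - 1) x nnj x nnb
            #   e_along    = k_a * att_forces * |disp . n|               # component along base normals
            #   e_vt       = k_b * att_forces * ||disp - (disp . n) n||  # component vertical to normals
            # i.e. per-frame joint displacements are decomposed along / perpendicular
            # to each base normal n and weighted by proximity to the base point.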
            # e_disp_rel_to_base_along_normals: bsz x (ws - 1) x nnj x nnb #
            e_disp_rel_to_base_along_normals = k_a * att_forces * torch.abs(signed_dist_base_pts_to_pert_rhand_joints_along_normal)
            # e_disp_rel_to_baes_vt_normals: bsz x (ws - 1) x nnj x nnb #
            e_disp_rel_to_baes_vt_normals = k_b * att_forces * dist_base_pts_to_pert_rhand_joints_vt_normal
            # print("per frame avg disp along normals:", init_image['per_frame_avg_disp_along_normals'].size(),
            #       "per frame std disp along normals:", init_image['per_frame_std_disp_along_normals'].size(),
            #       "e_disp_rel_to_base_along_normals:", e_disp_rel_to_base_along_normals.size(),
            #       "e_disp_rel_to_baes_vt_normals:", e_disp_rel_to_baes_vt_normals.size())
            # normalize the displacement energies with the per-frame statistics #
            e_disp_rel_to_base_along_normals = (e_disp_rel_to_base_along_normals - init_image['per_frame_avg_disp_along_normals']) / init_image['per_frame_std_disp_along_normals']
            e_disp_rel_to_baes_vt_normals = (e_disp_rel_to_baes_vt_normals - init_image['per_frame_avg_disp_vt_normals']) / init_image['per_frame_std_disp_vt_normals']

            # bsz x ws x nnj x nnb x 3 #
            # rel_base_pts_to_pert_rhand_joints = (rel_base_pts_to_pert_rhand_joints - init_image['per_frame_avg_joints_rel']) / init_image['per_frame_std_joints_rel']
            # dist_rel_base_pts_to_pert_rhand_joints: bsz x ws x nnj x nnb #
            dist_rel_base_pts_to_pert_rhand_joints = (dist_rel_base_pts_to_pert_rhand_joints - init_image['per_frame_avg_joints_dists_rel']) / init_image['per_frame_std_joints_dists_rel']
            dist_base_pts_to_pert_rhand_joints = dist_rel_base_pts_to_pert_rhand_joints
            rel_base_pts_to_pert_rhand_joints = rel_base_pts_to_pert_rhand_joints
            # rel_noise = th.randn_like(rel_base_pts_to_rhand_joints)
            # rel_base_pts_to_pert_rhand_joints = self.q_sample(rel_base_pts_to_rhand_joints, my_t, noise=rel_noise)  # sample perturbed joints #
            # dist_noise = th.randn_like(dist_base_pts_to_rhand_joints)
            # dist_base_pts_to_pert_rhand_joints = self.q_sample(dist_base_pts_to_rhand_joints, my_t, noise=dist_noise)
            ''' Normalization Strategy 4 '''
        elif self.denoising_stra == "motion_to_rep":
            my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0]
            joints_noise = th.randn_like(rhand_joints)
            pert_rhand_joints = self.q_sample(rhand_joints, my_t, noise=joints_noise)
            pert_rhand_joints_denorm = pert_rhand_joints * self.std_jts.unsqueeze(0).to(rhand_joints.device) + self.avg_jts.unsqueeze(0).to(rhand_joints.device)
            # bsz x nf x nnj x nnb x 3 #
            rel_base_pts_to_pert_rhand_joints = pert_rhand_joints_denorm.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1)
            dist_base_pts_to_pert_rhand_joints = torch.sum(
                rel_base_pts_to_pert_rhand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1
            )
        else:
            raise ValueError(f"Unrecognized denoising_stra: {self.denoising_stra}")

        input_data = {
            'base_pts': base_pts,
            'base_normals': base_normals,
            'rel_base_pts_to_rhand_joints': rel_base_pts_to_pert_rhand_joints,
            'dist_base_pts_to_rhand_joints': dist_base_pts_to_pert_rhand_joints,
            'pert_rhand_joints': pert_normed_rhand_joints,
            # 'pert_rhand_joints': pert_scaled_rhand_joints,
        }
        if 'sampled_base_pts_nearest_obj_pc' in init_image:
            input_data.update(ambient_init_image)
        input_data.update(
            {
                'e_disp_rel_to_base_along_normals': e_disp_rel_to_base_along_normals,
                'e_disp_rel_to_baes_vt_normals': e_disp_rel_to_baes_vt_normals,
            }
        )
        input_data.update(init_image_avg_std_stats)
        model_kwargs
= { k: val for k, val in init_image.items() if k not in input_data } if progress: # Lazy import so that we don't depend on tqdm. from tqdm.auto import tqdm indices = tqdm(indices) for i in indices: t = th.tensor([i] * shape[0], device=device) if randomize_class and 'y' in model_kwargs: model_kwargs['y'] = th.randint(low=0, high=model.num_classes, size=model_kwargs['y'].shape, # size of y device=model_kwargs['y'].device) # device of y with th.no_grad(): # inter_optim # # p_sample_with_grad sample_fn = self.p_sample_with_grad if cond_fn_with_grad else self.p_sample # out = sample_fn( model, input_data, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, const_noise=const_noise, ) # yield out img = out["sample"] # bsz x ws x nnj x nnb x 3 # ''' Relative positions and distances normalization, strategy 2 ''' # img = img * self.std_joints_rel.unsqueeze(-2) + self.avg_joints_rel.unsqueeze(-2) ''' Relative positions and distances normalization, strategy 2 ''' ''' Relative positions and distances normalization, strategy 4 ''' # img = img * (self.maxx_rel - self.minn_rel).unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0).to(img.device) ''' Relative positions and distances normalization, strategy 4 ''' if self.denoising_stra == "rep": # ''' Relative positions and distances normalization, strategy 3 ''' per_frame_avg_joints_rel = init_image['per_frame_avg_joints_rel'] per_frame_std_joints_rel = init_image['per_frame_std_joints_rel'] # std joints rel # # per_frame_avg_joints_dists_rel = init_image['per_frame_avg_joints_dists_rel'] # per_frame_std_joints_dists_rel = init_image['per_frame_std_joints_dists_rel'] img = img * per_frame_std_joints_rel + per_frame_avg_joints_rel ''' Relative positions and distances normalization, strategy 3 ''' ''' sampled base pts based joints ''' ## img + base_pts.unsqueeze(1).unsqueeze(1) ## # --> decrease the related potential when gathering information for sampled_rhand_joints --> # #### === todo -> from sampled rel to rhandjoints == sampled_base_pts_based_joints = img + base_pts.unsqueeze(1).unsqueeze(1) # bsz x ws x nnj x nnb x 3 sampled_rhand_joints = sampled_base_pts_based_joints.mean(dim=-2) ## a simple averaging strategy # sampled_rhand_joints = sampled_base_pts_based_joints[..., 0, :] ## a simple averaging strategy sampled_rhand_joints_cent = sampled_rhand_joints - avg_exp_rhand_joints.unsqueeze(1) scaled_sampled_rhand_joints = sampled_rhand_joints_cent * joints_scaling_factor normed_sampled_rhand_joints = (sampled_rhand_joints - self.avg_jts.unsqueeze(0).to(rhand_joints.device)) / self.std_jts.unsqueeze(0).to(rhand_joints.device) if self.inter_optim: # sampled_rhand_joints = model_util.optimize_sampled_hand_joints(sampled_rhand_joints, img, None, base_pts, base_normals) obj_verts, obj_normals, obj_faces = init_image["obj_verts"], init_image["obj_normals"], init_image["obj_faces"] # sampled_rhand_joints = model_util.optimize_sampled_hand_joints(sampled_rhand_joints, img, None, base_pts, base_normals) # optimize_sampled_hand_joints_wobj(sampled_joints, rel_base_pts_to_joints, dists_base_pts_to_joints, base_pts, base_normals, obj_verts, obj_normals, obj_faces) ### === sampled rhand joints === ### # sampled_rhand_joints = model_util.optimize_sampled_hand_joints_wobj(sampled_rhand_joints, img, None, base_pts, base_normals, obj_verts, obj_normals, obj_faces) # optimize_sampled_hand_joints_wobj_v2 sampled_rhand_joints = model_util.optimize_sampled_hand_joints_wobj_v2(sampled_rhand_joints, img, None, base_pts, 
base_normals, obj_verts, obj_normals, obj_faces) rel_base_pts_to_pert_rhand_joints = sampled_rhand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1) # bsz x ws x nnjts x nnb dist_base_pts_to_pert_rhand_joints = torch.sum( rel_base_pts_to_pert_rhand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1 ) k_f = 1. # l2_rel_base_pts_to_pert_rhand_joints: bsz x nf x nnj x nnb # l2_rel_base_pts_to_pert_rhand_joints = torch.norm(rel_base_pts_to_pert_rhand_joints, dim=-1) ### att_forces ## att_forces = torch.exp(-k_f * l2_rel_base_pts_to_pert_rhand_joints) # bsz x nf x nnj x nnb # att_forces = att_forces[:, :-1, :, :] # rhand_joints: ws x nnj x 3 # -> (ws - 1) x nnj x 3 ## rhand_joints ## # bsz x (ws - 1) x nnj x 3 --> displacements s# sampled_rhand_joints_disp = sampled_rhand_joints[:, 1:, :, :] - sampled_rhand_joints[:, :-1, :, :] # distance -- base_normalss,; (ws - 1) x nnj x nnb x 3 --> bsz x (ws - 1) x nnj x nnb # # signed_dist_base_pts_to_pert_rhand_joints_along_normal # bsz x (ws - 1) x nnj x nnb # signed_dist_base_pts_to_sampled_rhand_joints_along_normal = torch.sum( base_normals.unsqueeze(1).unsqueeze(1) * sampled_rhand_joints_disp.unsqueeze(-2), dim=-1 ) # rel_base_pts_to_pert_rhand_joints_vt_normal: bsz x (ws -1) x nnj x nnb x 3 -> the relative positions vertical to base normals # rel_base_pts_to_sampled_rhand_joints_vt_normal = sampled_rhand_joints_disp.unsqueeze(-2) - signed_dist_base_pts_to_sampled_rhand_joints_along_normal.unsqueeze(-1) * base_normals.unsqueeze(1).unsqueeze(1) dist_base_pts_to_sampled_rhand_joints_vt_normal = torch.sqrt(torch.sum( rel_base_pts_to_sampled_rhand_joints_vt_normal ** 2, dim=-1 )) k_a = 1. k_b = 1. e_disp_rel_to_base_along_normals = k_a * att_forces * torch.abs(signed_dist_base_pts_to_sampled_rhand_joints_along_normal) # (ws - 1) x nnj x nnb # -> dist vt normals # e_disp_rel_to_baes_vt_normals = k_b * att_forces * dist_base_pts_to_sampled_rhand_joints_vt_normal # nf x nnj x nnb ---> dist_vt_normals -> nf x nnj x nnb # # torch.sqrt() ## e_disp_rel_to_base_along_normals = (e_disp_rel_to_base_along_normals - init_image['per_frame_avg_disp_along_normals']) / init_image['per_frame_std_disp_along_normals'] e_disp_rel_to_baes_vt_normals = (e_disp_rel_to_baes_vt_normals - init_image['per_frame_avg_disp_vt_normals']) / init_image['per_frame_std_disp_vt_normals'] ## denoi; if self.denoising_stra == "rep": ''' Relative positions and distances normalization, strategy 3 ''' per_frame_avg_joints_dists_rel = init_image['per_frame_avg_joints_dists_rel'] per_frame_std_joints_dists_rel = init_image['per_frame_std_joints_dists_rel'] # rel base_pts to joints # rel_base_pts_to_pert_rhand_joints = (rel_base_pts_to_pert_rhand_joints - per_frame_avg_joints_rel ) / per_frame_std_joints_rel dist_base_pts_to_pert_rhand_joints = (dist_base_pts_to_pert_rhand_joints - per_frame_avg_joints_dists_rel) / per_frame_std_joints_dists_rel ''' Relative positions and distances normalization, strategy 3 ''' ## denoise in a regular representations space ## # if self.denoising_stra == "motion_to_rep": # sampled_rhand_joints_normed = (sampled_rhand_joints - self.avg_jts.unsqueeze(0).to(rhand_joints.device)) / self.std_jts.unsqueeze(0).to(rhand_joints.device) ''' Relative positions and distances normalization, strategy 4 ''' # rel_base_pts_to_pert_rhand_joints = rel_base_pts_to_pert_rhand_joints / (self.maxx_rel - self.minn_rel).unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0).to(img.device) # dist_base_pts_to_pert_rhand_joints = dist_base_pts_to_pert_rhand_joints / (self.maxx_dists 
- self.minn_dists).unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0).squeeze(-1).to(img.device) ''' Relative positions and distances normalization, strategy 4 ''' input_data = { 'sampled_rhand_joints': sampled_rhand_joints, # 'sampled_rhand_joints': pert_rhand_joints, # 'sampled_rhand_joints': base_pts.unsqueeze(1).repeat(1, rhand_joints.size(1), 1, 1), 'base_pts': base_pts, 'base_normals': base_normals, 'rel_base_pts_to_rhand_joints': rel_base_pts_to_pert_rhand_joints, # 'rel_base_pts_to_rhand_joints': img, 'dist_base_pts_to_rhand_joints': dist_base_pts_to_pert_rhand_joints, 'e_disp_rel_to_base_along_normals': e_disp_rel_to_base_along_normals, 'e_disp_rel_to_baes_vt_normals': e_disp_rel_to_baes_vt_normals, 'pert_rhand_joints': normed_sampled_rhand_joints, # scaled_sampled_rhand_joints # 'pert_rhand_joints': scaled_sampled_rhand_joints, # scaled_sampled_rhand_joints } if 'sampled_base_pts_nearest_obj_pc' in init_image: input_data.update(ambient_init_image) input_data.update(init_image_avg_std_stats) yield input_data def ddim_sample( self, model, x, t, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, eta=0.0, ): """ Sample x_{t-1} from the model using DDIM. Same usage as p_sample(). """ out_orig = self.p_mean_variance( model, x, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) if cond_fn is not None: out = self.condition_score(cond_fn, out_orig, x, t, model_kwargs=model_kwargs) else: out = out_orig # Usually our model outputs epsilon, but we re-derive it # in case we used x_start or x_prev prediction. eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) sigma = ( eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * th.sqrt(1 - alpha_bar / alpha_bar_prev) ) # Equation 12. ##; hand position and relative noise = th.randn_like(x) mean_pred = ( out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps ) nonzero_mask = ( (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) ) # no noise when t == 0 sample = mean_pred + nonzero_mask * sigma * noise return {"sample": sample, "pred_xstart": out_orig["pred_xstart"]} def ddim_sample_with_grad( self, model, x, t, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, eta=0.0, ): """ Sample x_{t-1} from the model using DDIM. Same usage as p_sample(). """ with th.enable_grad(): x = x.detach().requires_grad_() out_orig = self.p_mean_variance( model, x, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) if cond_fn is not None: out = self.condition_score_with_grad(cond_fn, out_orig, x, t, model_kwargs=model_kwargs) else: out = out_orig out["pred_xstart"] = out["pred_xstart"].detach() # Usually our model outputs epsilon, but we re-derive it # in case we used x_start or x_prev prediction. eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) sigma = ( eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * th.sqrt(1 - alpha_bar / alpha_bar_prev) ) # Equation 12. 
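        # DDIM update (Song et al. 2020, Eq. 12), as implemented just below:
        #   x_{t-1} = sqrt(alpha_bar_prev) * x0_hat
        #             + sqrt(1 - alpha_bar_prev - sigma^2) * eps
        #             + sigma * z,   z ~ N(0, I), with the noise dropped at t == 0;
        # sigma = eta * sqrt((1 - ab_prev) / (1 - ab)) * sqrt(1 - ab / ab_prev),
        # so eta = 0 recovers the deterministic DDIM path.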
noise = th.randn_like(x) mean_pred = ( out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps ) nonzero_mask = ( (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) ) # no noise when t == 0 sample = mean_pred + nonzero_mask * sigma * noise return {"sample": sample, "pred_xstart": out_orig["pred_xstart"].detach()} ## def ddim_reverse_sample( self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None, eta=0.0, ): """ Sample x_{t+1} from the model using DDIM reverse ODE. """ assert eta == 0.0, "Reverse ODE only for deterministic path" out = self.p_mean_variance( model, x, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) # Usually our model outputs epsilon, but we re-derive it # in case we used x_start or x_prev prediction. eps = ( _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - out["pred_xstart"] ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) # Equation 12. reversed mean_pred = ( out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps ) return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} def ddim_sample_loop( self, model, shape, noise=None, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, progress=False, eta=0.0, skip_timesteps=0, init_image=None, randomize_class=False, cond_fn_with_grad=False, dump_steps=None, const_noise=False, ): """ Generate samples from the model using DDIM. Same usage as p_sample_loop(). """ if dump_steps is not None: raise NotImplementedError() if const_noise == True: raise NotImplementedError() final = None for sample in self.ddim_sample_loop_progressive( model, shape, noise=noise, clip_denoised=clip_denoised, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, device=device, progress=progress, eta=eta, skip_timesteps=skip_timesteps, init_image=init_image, randomize_class=randomize_class, cond_fn_with_grad=cond_fn_with_grad, ): final = sample return final["sample"] def ddim_sample_loop_progressive( self, model, shape, noise=None, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, progress=False, eta=0.0, skip_timesteps=0, init_image=None, randomize_class=False, cond_fn_with_grad=False, ): """ Use DDIM to sample from the model and yield intermediate samples from each timestep of DDIM. Same usage as p_sample_loop_progressive(). """ if device is None: device = next(model.parameters()).device assert isinstance(shape, (tuple, list)) if noise is not None: img = noise else: img = th.randn(*shape, device=device) if skip_timesteps and init_image is None: init_image = th.zeros_like(img) indices = list(range(self.num_timesteps - skip_timesteps))[::-1] if init_image is not None: my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0] img = self.q_sample(init_image, my_t, img) if progress: # Lazy import so that we don't depend on tqdm. 
from tqdm.auto import tqdm indices = tqdm(indices) for i in indices: t = th.tensor([i] * shape[0], device=device) if randomize_class and 'y' in model_kwargs: model_kwargs['y'] = th.randint(low=0, high=model.num_classes, size=model_kwargs['y'].shape, device=model_kwargs['y'].device) with th.no_grad(): sample_fn = self.ddim_sample_with_grad if cond_fn_with_grad else self.ddim_sample out = sample_fn( model, img, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, eta=eta, ) yield out img = out["sample"] def plms_sample( self, model, x, t, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, cond_fn_with_grad=False, order=2, old_out=None, ): """ Sample x_{t-1} from the model using Pseudo Linear Multistep. Same usage as p_sample(). """ if not int(order) or not 1 <= order <= 4: raise ValueError('order is invalid (should be int from 1-4).') def get_model_output(x, t): with th.set_grad_enabled(cond_fn_with_grad and cond_fn is not None): x = x.detach().requires_grad_() if cond_fn_with_grad else x out_orig = self.p_mean_variance( model, x, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) if cond_fn is not None: if cond_fn_with_grad: out = self.condition_score_with_grad(cond_fn, out_orig, x, t, model_kwargs=model_kwargs) x = x.detach() else: out = self.condition_score(cond_fn, out_orig, x, t, model_kwargs=model_kwargs) else: out = out_orig # Usually our model outputs epsilon, but we re-derive it # in case we used x_start or x_prev prediction. eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) return eps, out, out_orig alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) eps, out, out_orig = get_model_output(x, t) if order > 1 and old_out is None: # Pseudo Improved Euler old_eps = [eps] mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev) * eps eps_2, _, _ = get_model_output(mean_pred, t - 1) eps_prime = (eps + eps_2) / 2 pred_prime = self._predict_xstart_from_eps(x, t, eps_prime) mean_pred = pred_prime * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev) * eps_prime else: # Pseudo Linear Multistep (Adams-Bashforth) old_eps = old_out["old_eps"] old_eps.append(eps) cur_order = min(order, len(old_eps)) if cur_order == 1: eps_prime = old_eps[-1] elif cur_order == 2: eps_prime = (3 * old_eps[-1] - old_eps[-2]) / 2 elif cur_order == 3: eps_prime = (23 * old_eps[-1] - 16 * old_eps[-2] + 5 * old_eps[-3]) / 12 elif cur_order == 4: eps_prime = (55 * old_eps[-1] - 59 * old_eps[-2] + 37 * old_eps[-3] - 9 * old_eps[-4]) / 24 else: raise RuntimeError('cur_order is invalid.') pred_prime = self._predict_xstart_from_eps(x, t, eps_prime) mean_pred = pred_prime * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev) * eps_prime if len(old_eps) >= order: old_eps.pop(0) nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) sample = mean_pred * nonzero_mask + out["pred_xstart"] * (1 - nonzero_mask) return {"sample": sample, "pred_xstart": out_orig["pred_xstart"], "old_eps": old_eps} def plms_sample_loop( self, model, shape, noise=None, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, progress=False, skip_timesteps=0, init_image=None, randomize_class=False, cond_fn_with_grad=False, order=2, ): """ Generate samples from the model using Pseudo Linear Multistep. Same usage as p_sample_loop(). 
""" final = None for sample in self.plms_sample_loop_progressive( model, shape, noise=noise, clip_denoised=clip_denoised, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, device=device, progress=progress, skip_timesteps=skip_timesteps, init_image=init_image, randomize_class=randomize_class, cond_fn_with_grad=cond_fn_with_grad, order=order, ): final = sample return final["sample"] def plms_sample_loop_progressive( self, model, shape, noise=None, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, progress=False, skip_timesteps=0, init_image=None, randomize_class=False, cond_fn_with_grad=False, order=2, ): """ Use PLMS to sample from the model and yield intermediate samples from each timestep of PLMS. Same usage as p_sample_loop_progressive(). """ if device is None: device = next(model.parameters()).device assert isinstance(shape, (tuple, list)) if noise is not None: img = noise else: img = th.randn(*shape, device=device) if skip_timesteps and init_image is None: init_image = th.zeros_like(img) indices = list(range(self.num_timesteps - skip_timesteps))[::-1] if init_image is not None: my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0] img = self.q_sample(init_image, my_t, img) if progress: # Lazy import so that we don't depend on tqdm. from tqdm.auto import tqdm indices = tqdm(indices) old_out = None for i in indices: t = th.tensor([i] * shape[0], device=device) if randomize_class and 'y' in model_kwargs: model_kwargs['y'] = th.randint(low=0, high=model.num_classes, size=model_kwargs['y'].shape, device=model_kwargs['y'].device) with th.no_grad(): out = self.plms_sample( model, img, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, cond_fn_with_grad=cond_fn_with_grad, order=order, old_out=old_out, ) yield out old_out = out img = out["sample"] def _vb_terms_bpd( self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None ): """ Get a term for the variational lower-bound. The resulting units are bits (rather than nats, as one might expect). This allows for comparison to other papers. :return: a dict with the following keys: - 'output': a shape [N] tensor of NLLs or KLs. - 'pred_xstart': the x_0 predictions. """ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( x_start=x_start, x_t=x_t, t=t ) out = self.p_mean_variance( model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs ) kl = normal_kl( true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] ) kl = mean_flat(kl) / np.log(2.0) decoder_nll = -discretized_gaussian_log_likelihood( x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] ) assert decoder_nll.shape == x_start.shape decoder_nll = mean_flat(decoder_nll) / np.log(2.0) # At the first timestep return the decoder NLL, # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) output = th.where((t == 0), decoder_nll, kl) return {"output": output, "pred_xstart": out["pred_xstart"]} ## training losses ## ## training losses ## def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, dataset=None): """ # training losses # training losses for rel/dist representations ## Compute training losses for a single timestep. :param model: the model to evaluate loss on. :param x_start: the [N x C x ...] tensor of inputs. :param t: a batch of timestep indices. :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. 
:param noise: if specified, the specific Gaussian noise to try to remove. :return: a dict with the key "loss" containing a tensor of shape [N]. Some mean or variance settings may also have other keys. """ enc = model.model ## model.model mask = model_kwargs['y']['mask'] ## rot2xyz get_xyz = lambda sample: enc.rot2xyz(sample, mask=None, pose_rep=enc.pose_rep, translation=enc.translation, glob=enc.glob, ## rot2xyz; ## # jointstype='vertices', # 3.4 iter/sec # USED ALSO IN MotionCLIP jointstype='smpl', # 3.4 iter/sec vertstrans=False) # bsz x ws x nnj x 3 # # base_pts: bsz x nnb x 3 # # base_normals: bsz x nnb x 3 # # bsz x ws x nnjts x 3 # rhand_joints = x_start['rhand_joints'] # bsz x nnbase x 3 # base_pts = x_start['base_pts'] # bsz x ws x nnbase x 3 # base_normals = x_start['base_normals'] # # bsz x ws x nnjts x nnbase x 3 # # rel_base_pts_to_rhand_joints = x_start['rel_base_pts_to_rhand_joints'] # # bsz x ws x nnjts x nnbase # # dist_base_pts_to_rhand_joints = x_start['dist_base_pts_to_rhand_joints'] # normalization strategy for joints and that for the representation values # if 'sampled_base_pts_nearest_obj_pc' in x_start: ambient_xstart_dict = { 'sampled_base_pts_nearest_obj_pc': x_start['sampled_base_pts_nearest_obj_pc'], 'sampled_base_pts_nearest_obj_vns': x_start['sampled_base_pts_nearest_obj_vns'], } if 'e_disp_rel_to_base_along_normals' in x_start: e_disp_xstart_dict = { # x_start['e_disp_rel_to_base_along_normals'], x_start['e_disp_rel_to_baes_vt_normals'], 'per_frame_avg_disp_along_normals': x_start['per_frame_avg_disp_along_normals'], 'per_frame_std_disp_along_normals': x_start['per_frame_std_disp_along_normals'], 'per_frame_avg_disp_vt_normals': x_start['per_frame_avg_disp_vt_normals'], 'per_frame_std_disp_vt_normals': x_start['per_frame_std_disp_vt_normals'], 'e_disp_rel_to_base_along_normals': x_start['e_disp_rel_to_base_along_normals'], 'e_disp_rel_to_baes_vt_normals': x_start['e_disp_rel_to_baes_vt_normals'], } joints_scaling_factor = 5. ''' GET rel and dists ''' # denoising if self.denoising_stra == "rep": # bsz x ws x nnj x nnb x 3 # # avg_jts: 1 x nnj x 3 # std_jts: 1 x nnj x 3 # rhand_joints: bsz x ws x nnj x 3; normalize rhand joitns # exp_rhand_joints = rhand_joints.view(rhand_joints.size(0) * rhand_joints.size(1), rhand_joints.size(2), 3) self.avg_jts = torch.mean(exp_rhand_joints, dim=0, keepdim=True) self.std_jts = torch.std(exp_rhand_joints, dim=0, keepdim=True) normed_rhand_joints = (rhand_joints - self.avg_jts.unsqueeze(0).to(rhand_joints.device)) / self.std_jts.unsqueeze(0).to(rhand_joints.device) noise_rhand_joints = th.randn_like(normed_rhand_joints) pert_normed_rhand_joints = self.q_sample(normed_rhand_joints, t, noise=noise_rhand_joints) ### scale rhand joints ## # rhand joints: bsz x ws x nnj x 3 exp_rhand_joints = rhand_joints.view(rhand_joints.size(0), rhand_joints.size(1) * rhand_joints.size(2), 3) maxx_exp_rhand_joints, _ = torch.max(exp_rhand_joints, dim=1, keepdim=True) minn_exp_rhand_joints, _ = torch.min(exp_rhand_joints, dim=1, keepdim=True) avg_exp_rhand_joints = (maxx_exp_rhand_joints + minn_exp_rhand_joints) / 2. 
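            # Minimal numeric sketch (illustrative comment only) of the centering
            # statistic above: for a coordinate spanning [-0.2, 0.4],
            #   (0.4 + (-0.2)) / 2. = 0.1
            # so subtracting avg_exp_rhand_joints below re-centers the sequence
            # bounding box at the origin before scaling by joints_scaling_factor.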
rhand_joints = rhand_joints - avg_exp_rhand_joints.unsqueeze(1) scaled_rhand_joints = rhand_joints * joints_scaling_factor noise_scaled_rhand_joints = th.randn_like(scaled_rhand_joints) pert_scaled_rhand_joints = self.q_sample(scaled_rhand_joints, t, noise=noise_scaled_rhand_joints) # pert_rhand_joints: bsz x nnj x 3 ## -> pert_rhand_joints = pert_normed_rhand_joints * self.std_jts.unsqueeze(0).to(rhand_joints.device) + self.avg_jts.unsqueeze(0).to(rhand_joints.device) # avg_pert_normed_rhand_joints = pert_normed_rhand_joints[0].mean(dim=0).mean(dim=0) # avg_pert_rhand_joints = pert_rhand_joints[0].mean(dim=0).mean(dim=0) # print(f"avg_normed_jts: {avg_pert_normed_rhand_joints}, avg_pert_rhand_joints: {avg_pert_rhand_joints}") # base_pts: bsz x nnb x 3 # # base_normals: bsz x nnb x 3 ## rel_base_pts_to_pert_rhand_joints = pert_rhand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1) # bsz x nnj x nnb x 3 -> rel_bae_pts to rhand joints # # bsz x ws x nnj x nnb x 3 ## # avg_rel_base_pts_to_pert_rhand_joints = rel_base_pts_to_pert_rhand_joints[0].mean(dim=0).mean(dim=0)[0] # print(f"Before normalization: avg_rel {avg_rel_base_pts_to_pert_rhand_joints}") dist_rel_base_pts_to_pert_rhand_joints = torch.sum( rel_base_pts_to_pert_rhand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1 ) k_f = 1. # l2_rel_base_pts_to_pert_rhand_joints: bsz x nf x nnj x nnb # l2_rel_base_pts_to_pert_rhand_joints = torch.norm(rel_base_pts_to_pert_rhand_joints, dim=-1) ### att_forces ## att_forces = torch.exp(-k_f * l2_rel_base_pts_to_pert_rhand_joints) # bsz x nf x nnj x nnb # att_forces = att_forces[:, :-1, :, :] # attraction forces -1 # # rhand_joints: ws x nnj x 3 # -> (ws - 1) x nnj x 3 ## rhand_joints ## # bsz x (ws - 1) x nnj x 3 --> displacements s# pert_rhand_joints_disp = pert_rhand_joints[:, 1:, :, :] - pert_rhand_joints[:, :-1, :, :] # distance -- base_normalss,; (ws - 1) x nnj x nnb x 3 --> bsz x (ws - 1) x nnj x nnb # # signed_dist_base_pts_to_pert_rhand_joints_along_normal # bsz x (ws - 1) x nnj x nnb # signed_dist_base_pts_to_pert_rhand_joints_along_normal = torch.sum( base_normals.unsqueeze(1).unsqueeze(1) * pert_rhand_joints_disp.unsqueeze(-2), dim=-1 ) # rel_base_pts_to_pert_rhand_joints_vt_normal: bsz x (ws -1) x nnj x nnb x 3 -> the relative positions vertical to base normals # rel_base_pts_to_pert_rhand_joints_vt_normal = pert_rhand_joints_disp.unsqueeze(-2) - signed_dist_base_pts_to_pert_rhand_joints_along_normal.unsqueeze(-1) * base_normals.unsqueeze(1).unsqueeze(1) dist_base_pts_to_pert_rhand_joints_vt_normal = torch.sqrt(torch.sum( rel_base_pts_to_pert_rhand_joints_vt_normal ** 2, dim=-1 )) k_a = 1. k_b = 1. 
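            # The displacement-energy terms computed next follow the same
            # decomposition sketched in p_sample_loop_progressive above:
            #   e_along = k_a * att_forces * |disp . n|
            #   e_vt    = k_b * att_forces * ||disp - (disp . n) n||
            # here evaluated on the perturbed (forward-noised) joints and passed to
            # the model as conditioning features.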
e_disp_rel_to_base_along_normals = k_a * att_forces * torch.abs(signed_dist_base_pts_to_pert_rhand_joints_along_normal) # (ws - 1) x nnj x nnb # -> dist vt normals # e_disp_rel_to_baes_vt_normals = k_b * att_forces * dist_base_pts_to_pert_rhand_joints_vt_normal # nf x nnj x nnb ---> dist_vt_normals -> nf x nnj x nnb # # torch.sqrt() ## # dist_base_pts_to_pert_rhand_joints, rel_base_pts_to_pert_rhand_joints # # e_disp_rel_to_base_along_normals, e_disp_rel_to_baes_vt_normals # # print("per frmae avg jalong normals,", x_start['per_frame_avg_disp_along_normals'].size(), "per frame std along normals", x_start['per_frame_std_disp_along_normals'].size(), "e_disp_rel_to_base_along_normals", e_disp_rel_to_base_along_normals.size(), "e_disp_rel_to_baes_vt_normals", e_disp_rel_to_baes_vt_normals.size()) e_disp_rel_to_base_along_normals = (e_disp_rel_to_base_along_normals - x_start['per_frame_avg_disp_along_normals']) / x_start['per_frame_std_disp_along_normals'] e_disp_rel_to_baes_vt_normals = (e_disp_rel_to_baes_vt_normals - x_start['per_frame_avg_disp_vt_normals']) / x_start['per_frame_std_disp_vt_normals'] ### # bsz x ws x nnj x nnb x 3 # # rel_base_pts_to_pert_rhand_joints = (rel_base_pts_to_pert_rhand_joints - x_start['per_frame_avg_joints_rel']) / x_start['per_frame_std_joints_rel'] # dist_rel_base_pts_to_pert_rhand_joints: bsz x ws x nnj x nnb # --> rel and dists ## dist_rel_base_pts_to_pert_rhand_joints = (dist_rel_base_pts_to_pert_rhand_joints - x_start['per_frame_avg_joints_dists_rel'] ) / x_start['per_frame_std_joints_dists_rel'] dist_base_pts_to_pert_rhand_joints = dist_rel_base_pts_to_pert_rhand_joints rel_base_pts_to_pert_rhand_joints = rel_base_pts_to_pert_rhand_joints # # bsz x ws x nnj x nnb x 3 ## # avg_rel_base_pts_to_pert_rhand_joints = rel_base_pts_to_pert_rhand_joints[0].mean(dim=0).mean(dim=0)[0] # print(f"After normalization: avg_rel {avg_rel_base_pts_to_pert_rhand_joints}") # if noise is None: # rel_noise = th.randn_like(rel_base_pts_to_rhand_joints) # # per_frame_avg_joints_rel = x_start['per_frame_avg_joints_rel'] # # per_frame_std_joints_rel = x_start['per_frame_std_joints_rel'] # std joints rel # # # denorm_rel_base_pts_to_rhand_joints = (rel_base_pts_to_rhand_joints * per_frame_std_joints_rel) + per_frame_avg_joints_rel # # denorm_pert_rel_base_pts_to_rhand_joints = # # norm_rel_noise = (rel_noise - ) # ''' stra 1 -> independent noise ''' # # bsz x ws x nnjts x 3 # ## sample perturbed joints ## # # rel_base_pts_to_pert_rhand_joints = self.q_sample(rel_base_pts_to_rhand_joints, t, noise=rel_noise) # ''' stra 2 -> same noise ''' # rel_noise = rel_noise[:, :, 0, :, :].unsqueeze(2).repeat(1, 1, rel_base_pts_to_rhand_joints.size(2), 1, 1).contiguous() # rel_base_pts_to_pert_rhand_joints = self.q_sample(rel_base_pts_to_rhand_joints, t, noise=rel_noise) # ''' stra 2 -> same noise ''' # # normalization for each framej -> the relative positions and signed distances # # dist_noise = th.randn_like(dist_base_pts_to_rhand_joints) # dist_base_pts_to_pert_rhand_joints = self.q_sample(dist_base_pts_to_rhand_joints, t, noise=dist_noise) elif self.denoising_stra == "motion_to_rep": # print(f"Using denoising stra: {self.denoising_stra}") joints_noise = torch.randn_like(rhand_joints) pert_rhand_joints = self.q_sample(rhand_joints, t, noise=joints_noise) # q_sample for the noisy joints # pert_rhand_joints: bsz x nf x nnj x 3 ## --> pert joints # base_pts: bsz x nnb x 3 # avg jts and std jts ## pert_rhand_joints_denorm = pert_rhand_joints * self.std_jts.unsqueeze(0).to(rhand_joints.device) + 
self.avg_jts.unsqueeze(0).to(rhand_joints.device) # jbs zx nf x nnj x nnb x 3 rel_base_pts_to_pert_rhand_joints = pert_rhand_joints_denorm.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1) dist_base_pts_to_pert_rhand_joints = torch.sum( rel_base_pts_to_pert_rhand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1 ) else: raise ValueError(f"Unrecognized denoising_stra: {self.denoising_stra}") ''' GET rel and dists ''' input_data = { 'base_pts': base_pts.clone(), 'base_normals': base_normals.clone(), 'rel_base_pts_to_rhand_joints': rel_base_pts_to_pert_rhand_joints.clone(), 'dist_base_pts_to_rhand_joints': dist_base_pts_to_pert_rhand_joints.clone(), 'pert_rhand_joints': pert_normed_rhand_joints, # scaled_rhand_joints, pert_scaled_rhand_joints # 'pert_rhand_joints': pert_scaled_rhand_joints, } if 'sampled_base_pts_nearest_obj_pc' in x_start: input_data.update(ambient_xstart_dict) # dist_base_pts_to_pert_rhand_joints, rel_base_pts_to_pert_rhand_joints # # e_disp_rel_to_base_along_normals, e_disp_rel_to_baes_vt_normals # # bsz x ws - 1 x nnj x nnb # input_data.update( { 'e_disp_rel_to_base_along_normals': e_disp_rel_to_base_along_normals, 'e_disp_rel_to_baes_vt_normals': e_disp_rel_to_baes_vt_normals, } ) # input_data.update( # {k: x_start[k].clone() for k in x_start if k not in input_data} # ) if model_kwargs is None: model_kwargs = {} # if noise is None: # noise = th.randn_like(x_start) # x_t = self.q_sample(x_start, t, noise=noise) terms = {} joint_seq_output, model_output, e_model_output = model(input_data, self._scale_timesteps(t).clone()) # assert model_output.shape == target.shape == x_start.shape # [bs, njoints, nfeats, nframes] target = x_start target_rel_base_pts_to_jts = target['rel_base_pts_to_rhand_joints'] target_dist_base_pts_to_jts = target['dist_base_pts_to_rhand_joints'] dec_rel = model_output['dec_rel'] dec_dist = model_output['dec_dist'] dec_rhand_joints = joint_seq_output # e_model_output, dec_e_along_normalss, dec_e_vt_normals # dec_e_along_normals = e_model_output['dec_e_along_normals'] dec_e_vt_normals = e_model_output['dec_e_vt_normals'] # print(f'target_rel_base_pts_to_jts: {target_rel_base_pts_to_jts.size()}, target_dist_base_pts_to_jts: {target_dist_base_pts_to_jts.size()}, dec_rel: {dec_rel.size()}, dec_dist: {dec_dist.size()}') # terms['rot_mse'] = torch.sum( # (target_rel_base_pts_to_jts - dec_rel) ** 2, dim=-1 # ).mean() + torch.mean( # (target_dist_base_pts_to_jts - dec_dist) ** 2, dim=-1 # # ) # rel pred; dist pred; # terms['rel_pred'] = torch.sum( (target_rel_base_pts_to_jts - dec_rel) ** 2, dim=-1 ).mean() terms['dist_pred'] = ((target_dist_base_pts_to_jts - dec_dist) ** 2).mean() ### rel pred loss ### rel_pred_loss = torch.sum( (target_rel_base_pts_to_jts - dec_rel) ** 2, dim=-1 ).mean(dim=-1).mean(dim=-1).mean(dim=-1) dist_pred_loss = ((target_dist_base_pts_to_jts - dec_dist) ** 2).mean(dim=-1).mean(dim=-1).mean(dim=-1) dec_e_along_normals_loss = ((dec_e_along_normals - x_start['e_disp_rel_to_base_along_normals']) ** 2).mean(dim=-1).mean(dim=-1).mean(dim=-1) dec_e_vt_normals_loss = ((dec_e_vt_normals - x_start['e_disp_rel_to_baes_vt_normals']) ** 2).mean(dim=-1).mean(dim=-1).mean(dim=-1) # # joints_pred_loss: bsz x ws x nnj joints_pred_loss = torch.sum( (dec_rhand_joints - normed_rhand_joints) ** 2, dim=-1 ).mean(dim=-1).mean(dim=-1) # scaled_rhand_joints, pert_scaled_rhand_joints # joints_pred_loss = torch.sum( # (dec_rhand_joints - scaled_rhand_joints) ** 2, dim=-1 # ).mean(dim=-1).mean(dim=-1) terms['rel_pred_loss'] = rel_pred_loss 
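        # Each of the *_loss terms above is reduced to one scalar per batch element
        # (shape [bsz]) by averaging over the trailing dimensions, e.g.
        #   ((pred - tgt) ** 2).mean(dim=-1).mean(dim=-1).mean(dim=-1)
        # so the individual terms can be summed into the per-sample 'rot_mse'
        # objective below.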
terms['dist_pred_loss'] = dist_pred_loss terms['dec_e_along_normals_loss'] = dec_e_along_normals_loss terms['dec_e_vt_normals_loss'] = dec_e_vt_normals_loss terms['joints_pred_loss'] = joints_pred_loss terms['rot_mse'] = (torch.sum( (target_rel_base_pts_to_jts - dec_rel) ** 2, dim=-1 ) + (target_dist_base_pts_to_jts - dec_dist) ** 2).mean(dim=-1).mean(dim=-1).mean(dim=-1) # x_start['e_disp_rel_to_base_along_normals'], x_start['e_disp_rel_to_baes_vt_normals'], terms['rot_mse'] = terms['rot_mse'] + ((dec_e_along_normals - x_start['e_disp_rel_to_base_along_normals']) ** 2).mean(dim=-1).mean(dim=-1).mean(dim=-1) + ((dec_e_vt_normals - x_start['e_disp_rel_to_baes_vt_normals']) ** 2).mean(dim=-1).mean(dim=-1).mean(dim=-1) terms['rot_mse'] = terms['rot_mse'] + joints_pred_loss ### joints pred loss ## # terms["rot_mse"] = self.masked_l2(target, model_output, mask) # mean_flat(rot_mse) # sv_out_in = { # 'model_output': model_output.detach().cpu().numpy(), # 'target': target.detach().cpu().numpy(), # 't': t.detach().cpu().numpy(), # } import os import datetime cur_time_stamp = datetime.datetime.now().timestamp() cur_time_stamp = str(cur_time_stamp) sv_dir_rt = "/data1/sim/motion-diffusion-model" # sv_out_fn = os.path.join(sv_dir_rt, f"out_{cur_time_stamp}.npy") # np.save(sv_out_fn,sv_out_in ) # print(f"Samples saved to {sv_out_fn}") target_xyz, model_output_xyz = None, None if self.lambda_rcxyz > 0. and dataset.dataname not in ['motion_ours']: print(f"Calculating lambda_rcxyz!!!") target_xyz = get_xyz(target) # [bs, nvertices(vertices)/njoints(smpl), 3, nframes] model_output_xyz = get_xyz(model_output) # [bs, nvertices, 3, nframes] terms["rcxyz_mse"] = self.masked_l2(target_xyz, model_output_xyz, mask) # mean_flat((target_xyz - model_output_xyz) ** 2) if self.lambda_vel_rcxyz > 0.: if self.data_rep == 'rot6d' and dataset.dataname in ['humanact12', 'uestc']: target_xyz = get_xyz(target) if target_xyz is None else target_xyz model_output_xyz = get_xyz(model_output) if model_output_xyz is None else model_output_xyz target_xyz_vel = (target_xyz[:, :, :, 1:] - target_xyz[:, :, :, :-1]) model_output_xyz_vel = (model_output_xyz[:, :, :, 1:] - model_output_xyz[:, :, :, :-1]) terms["vel_xyz_mse"] = self.masked_l2(target_xyz_vel, model_output_xyz_vel, mask[:, :, :, 1:]) if self.lambda_fc > 0.: ## lambda fc ## torch.autograd.set_detect_anomaly(True) if self.data_rep == 'rot6d' and dataset.dataname in ['humanact12', 'uestc']: target_xyz = get_xyz(target) if target_xyz is None else target_xyz model_output_xyz = get_xyz(model_output) if model_output_xyz is None else model_output_xyz # 'L_Ankle', # 7, 'R_Ankle', # 8 , 'L_Foot', # 10, 'R_Foot', # 11 l_ankle_idx, r_ankle_idx, l_foot_idx, r_foot_idx = 7, 8, 10, 11 relevant_joints = [l_ankle_idx, l_foot_idx, r_ankle_idx, r_foot_idx] gt_joint_xyz = target_xyz[:, relevant_joints, :, :] # [BatchSize, 4, 3, Frames] gt_joint_vel = torch.linalg.norm(gt_joint_xyz[:, :, :, 1:] - gt_joint_xyz[:, :, :, :-1], axis=2) # [BatchSize, 4, Frames] fc_mask = torch.unsqueeze((gt_joint_vel <= 0.01), dim=2).repeat(1, 1, 3, 1) pred_joint_xyz = model_output_xyz[:, relevant_joints, :, :] # [BatchSize, 4, 3, Frames] pred_vel = pred_joint_xyz[:, :, :, 1:] - pred_joint_xyz[:, :, :, :-1] pred_vel[~fc_mask] = 0 terms["fc"] = self.masked_l2(pred_vel, torch.zeros(pred_vel.shape, device=pred_vel.device), mask[:, :, :, 1:]) self.lambda_vel = 0. 
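        # Note: self.lambda_vel is zeroed immediately above, so the velocity branch
        # below is effectively disabled here; the final objective then reduces to
        #   loss = rot_mse + vb + lambda_vel * vel_mse + lambda_rcxyz * rcxyz_mse + lambda_fc * fc
        # with terms.get(..., 0.) supplying zeros for the absent terms.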
if self.lambda_vel > 0.: target_vel = (target[..., 1:] - target[..., :-1]) model_output_vel = (model_output[..., 1:] - model_output[..., :-1]) terms["vel_mse"] = self.masked_l2(target_vel[:, :-1, :, :], # Remove last joint, is the root location! model_output_vel[:, :-1, :, :], mask[:, :, :, 1:]) # mean_flat((target_vel - model_output_vel) ** 2) terms["loss"] = terms["rot_mse"] + terms.get('vb', 0.) +\ (self.lambda_vel * terms.get('vel_mse', 0.)) +\ (self.lambda_rcxyz * terms.get('rcxyz_mse', 0.)) + \ (self.lambda_fc * terms.get('fc', 0.)) # else: # raise NotImplementedError(self.loss_type) return terms ## training losses def predict_sample_single_step(self, model, x_start, t, model_kwargs=None, noise=None, dataset=None): """ # s Compute training losses for a single timestep. :param model: the model to evaluate loss on. :param x_start: the [N x C x ...] tensor of inputs. :param t: a batch of timestep indices. :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :param noise: if specified, the specific Gaussian noise to try to remove. :return: a dict with the key "loss" containing a tensor of shape [N]. Some mean or variance settings may also have other keys. """ # enc = model.model._modules['module'] # enc = model.model mask = model_kwargs['y']['mask'] ## rot2xyz; # # ### avg_joints, std_joints ### # if 'avg_joints' in model_kwargs['y']: avg_joints = model_kwargs['y']['avg_joints'].unsqueeze(-1) std_joints = model_kwargs['y']['std_joints'].unsqueeze(-1).unsqueeze(-1) else: avg_joints = None std_joints = None # ### avg_joints, std_joints ### # # get_xyz = lambda sample: enc.rot2xyz(sample, mask=None, pose_rep=enc.pose_rep, translation=enc.translation, # glob=enc.glob, ## rot2xyz; ## # # jointstype='vertices', # 3.4 iter/sec # USED ALSO IN MotionCLIP # jointstype='smpl', # 3.4 iter/sec # vertstrans=False) if model_kwargs is None: ## predict single steps --> model_kwargs = {} if noise is None: noise = th.randn_like(x_start) x_t = self.q_sample(x_start, t, noise=noise) # randn_like for x_start, t, x_t --> get x_t from x_start # # how we control the tiem stamp t? terms = {} if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: terms["loss"] = self._vb_terms_bpd( ## vb terms bpd # model=model, x_start=x_start, x_t=x_t, t=t, clip_denoised=False, model_kwargs=model_kwargs, )["output"] if self.loss_type == LossType.RESCALED_KL: terms["loss"] *= self.num_timesteps elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: model_output = model(x_t, self._scale_timesteps(t), **model_kwargs) # model_output ---> model x_t # if self.model_var_type in [ ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE, ]: # s B, C = x_t.shape[:2] assert model_output.shape == (B, C * 2, *x_t.shape[2:]) model_output, model_var_values = th.split(model_output, C, dim=1) # Learn the variance using the variational bound, but don't let # it affect our mean prediction. frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) terms["vb"] = self._vb_terms_bpd( model=lambda *args, r=frozen_out: r, x_start=x_start, x_t=x_t, t=t, clip_denoised=False, )["output"] if self.loss_type == LossType.RESCALED_MSE: # Divide by 1000 for equivalence with initial implementation. # Without a factor of 1/1000, the VB term hurts the MSE term. 
terms["vb"] *= self.num_timesteps / 1000.0 target = { # q posterior mean variance # ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( x_start=x_start, x_t=x_t, t=t )[0], ModelMeanType.START_X: x_start, ModelMeanType.EPSILON: noise, }[self.model_mean_type] # model mean type --> mean type # assert model_output.shape == target.shape == x_start.shape # [bs, njoints, nfeats, nframes] terms["rot_mse"] = self.masked_l2(target, model_output, mask) # mean_flat(rot_mse) # model_output, target, t # ### avg_joints, std_joints ### # if avg_joints is not None: print(f"model_output: {model_output.size()}, target: {target.size()}, std_joints: {std_joints.size()}, avg_joints: {avg_joints.size()}") print(f"Denormalizing joints...") model_output = (model_output * std_joints) + avg_joints target = (target * std_joints) + avg_joints sv_out_in = { # 'model_output': model_output.detach().cpu().numpy(), 'target': target.detach().cpu().numpy(), 't': t.detach().cpu().numpy(), } import os import datetime cur_time_stamp = datetime.datetime.now().timestamp() cur_time_stamp = str(cur_time_stamp) sv_dir_rt = "/data1/sim/motion-diffusion-model" sv_out_fn = os.path.join(sv_dir_rt, f"out_{cur_time_stamp}.npy") np.save(sv_out_fn,sv_out_in ) print(f"Samples saved to {sv_out_fn}") target_xyz, model_output_xyz = None, None self.lambda_rcxyz = 0. if self.lambda_rcxyz > 0.: target_xyz = get_xyz(target) # [bs, nvertices(vertices)/njoints(smpl), 3, nframes] model_output_xyz = get_xyz(model_output) # [bs, nvertices, 3, nframes] terms["rcxyz_mse"] = self.masked_l2(target_xyz, model_output_xyz, mask) # mean_flat((target_xyz - model_output_xyz) ** 2) if self.lambda_vel_rcxyz > 0.: if self.data_rep == 'rot6d' and dataset.dataname in ['humanact12', 'uestc']: target_xyz = get_xyz(target) if target_xyz is None else target_xyz model_output_xyz = get_xyz(model_output) if model_output_xyz is None else model_output_xyz target_xyz_vel = (target_xyz[:, :, :, 1:] - target_xyz[:, :, :, :-1]) model_output_xyz_vel = (model_output_xyz[:, :, :, 1:] - model_output_xyz[:, :, :, :-1]) terms["vel_xyz_mse"] = self.masked_l2(target_xyz_vel, model_output_xyz_vel, mask[:, :, :, 1:]) if self.lambda_fc > 0.: ## lambda fc ## torch.autograd.set_detect_anomaly(True) if self.data_rep == 'rot6d' and dataset.dataname in ['humanact12', 'uestc']: target_xyz = get_xyz(target) if target_xyz is None else target_xyz model_output_xyz = get_xyz(model_output) if model_output_xyz is None else model_output_xyz # 'L_Ankle', # 7, 'R_Ankle', # 8 , 'L_Foot', # 10, 'R_Foot', # 11 l_ankle_idx, r_ankle_idx, l_foot_idx, r_foot_idx = 7, 8, 10, 11 relevant_joints = [l_ankle_idx, l_foot_idx, r_ankle_idx, r_foot_idx] gt_joint_xyz = target_xyz[:, relevant_joints, :, :] # [BatchSize, 4, 3, Frames] gt_joint_vel = torch.linalg.norm(gt_joint_xyz[:, :, :, 1:] - gt_joint_xyz[:, :, :, :-1], axis=2) # [BatchSize, 4, Frames] fc_mask = torch.unsqueeze((gt_joint_vel <= 0.01), dim=2).repeat(1, 1, 3, 1) pred_joint_xyz = model_output_xyz[:, relevant_joints, :, :] # [BatchSize, 4, 3, Frames] pred_vel = pred_joint_xyz[:, :, :, 1:] - pred_joint_xyz[:, :, :, :-1] pred_vel[~fc_mask] = 0 terms["fc"] = self.masked_l2(pred_vel, torch.zeros(pred_vel.shape, device=pred_vel.device), mask[:, :, :, 1:]) if self.lambda_vel > 0.: target_vel = (target[..., 1:] - target[..., :-1]) model_output_vel = (model_output[..., 1:] - model_output[..., :-1]) terms["vel_mse"] = self.masked_l2(target_vel[:, :-1, :, :], # Remove last joint, is the root location! 
model_output_vel[:, :-1, :, :], mask[:, :, :, 1:]) # mean_flat((target_vel - model_output_vel) ** 2) terms["loss"] = terms["rot_mse"] + terms.get('vb', 0.) +\ (self.lambda_vel * terms.get('vel_mse', 0.)) +\ (self.lambda_rcxyz * terms.get('rcxyz_mse', 0.)) + \ (self.lambda_fc * terms.get('fc', 0.)) else: raise NotImplementedError(self.loss_type) return terms, model_output, target, t def fc_loss_rot_repr(self, gt_xyz, pred_xyz, mask): def to_np_cpu(x): return x.detach().cpu().numpy() """ pose_xyz: SMPL batch tensor of shape: [BatchSize, 24, 3, Frames] """ # 'L_Ankle', # 7, 'R_Ankle', # 8 , 'L_Foot', # 10, 'R_Foot', # 11 l_ankle_idx, r_ankle_idx = 7, 8 l_foot_idx, r_foot_idx = 10, 11 """ Contact calculated by 'Kfir Method' Commented code)""" # contact_signal = torch.zeros((pose_xyz.shape[0], pose_xyz.shape[3], 2), device=pose_xyz.device) # [BatchSize, Frames, 2] # left_xyz = 0.5 * (pose_xyz[:, l_ankle_idx, :, :] + pose_xyz[:, l_foot_idx, :, :]) # [BatchSize, 3, Frames] # right_xyz = 0.5 * (pose_xyz[:, r_ankle_idx, :, :] + pose_xyz[:, r_foot_idx, :, :]) # left_z, right_z = left_xyz[:, 2, :], right_xyz[:, 2, :] # [BatchSize, Frames] # left_velocity = torch.linalg.norm(left_xyz[:, :, 2:] - left_xyz[:, :, :-2], axis=1) # [BatchSize, Frames] # right_velocity = torch.linalg.norm(left_xyz[:, :, 2:] - left_xyz[:, :, :-2], axis=1) # # left_z_mask = left_z <= torch.mean(torch.sort(left_z)[0][:, :left_z.shape[1] // 5], axis=-1) # left_z_mask = torch.stack([left_z_mask, left_z_mask], dim=-1) # [BatchSize, Frames, 2] # left_z_mask[:, :, 1] = False # Blank right side # contact_signal[left_z_mask] = 0.4 # # right_z_mask = right_z <= torch.mean(torch.sort(right_z)[0][:, :right_z.shape[1] // 5], axis=-1) # right_z_mask = torch.stack([right_z_mask, right_z_mask], dim=-1) # [BatchSize, Frames, 2] # right_z_mask[:, :, 0] = False # Blank left side # contact_signal[right_z_mask] = 0.4 # contact_signal[left_z <= (torch.mean(torch.sort(left_z)[:left_z.shape[0] // 5]) + 20), 0] = 1 # contact_signal[right_z <= (torch.mean(torch.sort(right_z)[:right_z.shape[0] // 5]) + 20), 1] = 1 # plt.plot(to_np_cpu(left_z[0]), label='left_z') # plt.plot(to_np_cpu(left_velocity[0]), label='left_velocity') # plt.plot(to_np_cpu(contact_signal[0, :, 0]), label='left_fc') # plt.grid() # plt.legend() # plt.show() # plt.plot(to_np_cpu(right_z[0]), label='right_z') # plt.plot(to_np_cpu(right_velocity[0]), label='right_velocity') # plt.plot(to_np_cpu(contact_signal[0, :, 1]), label='right_fc') # plt.grid() # plt.legend() # plt.show() gt_joint_xyz = gt_xyz[:, [l_ankle_idx, l_foot_idx, r_ankle_idx, r_foot_idx], :, :] # [BatchSize, 4, 3, Frames] gt_joint_vel = torch.linalg.norm(gt_joint_xyz[:, :, :, 1:] - gt_joint_xyz[:, :, :, :-1], axis=2) # [BatchSize, 4, Frames] fc_mask = (gt_joint_vel <= 0.01) pred_joint_xyz = pred_xyz[:, [l_ankle_idx, l_foot_idx, r_ankle_idx, r_foot_idx], :, :] # [BatchSize, 4, 3, Frames] pred_joint_vel = torch.linalg.norm(pred_joint_xyz[:, :, :, 1:] - pred_joint_xyz[:, :, :, :-1], axis=2) # [BatchSize, 4, Frames] pred_joint_vel[~fc_mask] = 0 # Blank non-contact velocities frames. 
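        # --- Added sketch (hypothetical tensors, not pipeline code) ---
        # The contact heuristic: a foot joint counts as planted when its GT
        # frame-to-frame speed is <= 0.01; predicted velocities are kept only
        # on planted frames, where the L2 term pushes them toward zero:
        #
        #   gt_vel = torch.linalg.norm(gt[:, :, :, 1:] - gt[:, :, :, :-1], dim=2)
        #   planted = gt_vel <= 0.01        # [bs, 4, frames - 1]
        #   pred_vel[~planted] = 0          # non-contact frames contribute nothing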
        # [BS, 4, Frames]
        pred_joint_vel = torch.unsqueeze(pred_joint_vel, dim=2)

        """DEBUG CODE"""
        # print(f'mask: {mask.shape}')
        # print(f'pred_joint_vel: {pred_joint_vel.shape}')
        # plt.title(f'Joint: {joint_idx}')
        # plt.plot(to_np_cpu(gt_joint_vel[0]), label='velocity')
        # plt.plot(to_np_cpu(fc_mask[0]), label='fc')
        # plt.grid()
        # plt.legend()
        # plt.show()
        return self.masked_l2(pred_joint_vel,
                              torch.zeros(pred_joint_vel.shape, device=pred_joint_vel.device),
                              mask[:, :, :, 1:])

    # TODO - NOT USED YET, JUST COMMITTING TO NOT DELETE THIS AND KEEP INITIAL IMPLEMENTATION, NOT DONE!
    def foot_contact_loss_humanml3d(self, target, model_output):
        # Feature layout along dim=1 of the HumanML3D vector:
        # root_rot_velocity (B, seq_len, 1)
        # root_linear_velocity (B, seq_len, 2)
        # root_y (B, seq_len, 1)
        # ric_data (B, seq_len, (joint_num - 1)*3), XYZ
        # rot_data (B, seq_len, (joint_num - 1)*6), 6D
        # local_velocity (B, seq_len, joint_num*3), XYZ
        # foot contact (B, seq_len, 4)
        target_fc = target[:, -4:, :, :]
        root_rot_velocity = target[:, :1, :, :]
        root_linear_velocity = target[:, 1:3, :, :]
        root_y = target[:, 3:4, :, :]
        ric_data = target[:, 4:67, :, :]  # 4 + (3 * 21) = 67
        rot_data = target[:, 67:193, :, :]  # 67 + (6 * 21) = 193
        local_velocity = target[:, 193:259, :, :]  # 193 + (3 * 22) = 259
        contact = target[:, 259:, :, :]
        contact_mask_gt = contact > 0.5  # contact mask order for indexes are fid_l [7, 10], fid_r [8, 11]
        vel_lf_7 = local_velocity[:, 7 * 3:8 * 3, :, :]
        vel_rf_8 = local_velocity[:, 8 * 3:9 * 3, :, :]
        vel_lf_10 = local_velocity[:, 10 * 3:11 * 3, :, :]
        vel_rf_11 = local_velocity[:, 11 * 3:12 * 3, :, :]
        calc_vel_lf_7 = ric_data[:, 6 * 3:7 * 3, :, 1:] - ric_data[:, 6 * 3:7 * 3, :, :-1]
        calc_vel_rf_8 = ric_data[:, 7 * 3:8 * 3, :, 1:] - ric_data[:, 7 * 3:8 * 3, :, :-1]
        calc_vel_lf_10 = ric_data[:, 9 * 3:10 * 3, :, 1:] - ric_data[:, 9 * 3:10 * 3, :, :-1]
        calc_vel_rf_11 = ric_data[:, 10 * 3:11 * 3, :, 1:] - ric_data[:, 10 * 3:11 * 3, :, :-1]
        # vel_foots = torch.stack([vel_lf_7, vel_lf_10, vel_rf_8, vel_rf_11], dim=1)
        for chosen_vel_foot_calc, chosen_vel_foot, joint_idx, contact_mask_idx in zip(
                [calc_vel_lf_7, calc_vel_rf_8, calc_vel_lf_10, calc_vel_rf_11],
                [vel_lf_7, vel_lf_10, vel_rf_8, vel_rf_11],
                [7, 10, 8, 11],
                [0, 1, 2, 3]):
            tmp_mask_gt = contact_mask_gt[:, contact_mask_idx, :, :].cpu().detach().numpy().reshape(-1).astype(int)
            chosen_vel_norm = np.linalg.norm(chosen_vel_foot.cpu().detach().numpy().reshape((3, -1)), axis=0)
            chosen_vel_calc_norm = np.linalg.norm(chosen_vel_foot_calc.cpu().detach().numpy().reshape((3, -1)), axis=0)
            print(tmp_mask_gt.shape)
            print(chosen_vel_foot.shape)
            print(chosen_vel_calc_norm.shape)
            import matplotlib.pyplot as plt
            plt.plot(tmp_mask_gt, label='FC mask')
            plt.plot(chosen_vel_norm, label='Vel. XYZ norm (from vector)')
            plt.plot(chosen_vel_calc_norm, label='Vel. XYZ norm (calculated diff XYZ)')
            plt.title(f'FC idx {contact_mask_idx}, Joint Index {joint_idx}')
            plt.legend()
            plt.show()
        # print(vel_foots.shape)
        return 0

    # TODO - NOT USED YET, JUST COMMITTING TO NOT DELETE THIS AND KEEP INITIAL IMPLEMENTATION, NOT DONE!
def velocity_consistency_loss_humanml3d(self, target, model_output): # root_rot_velocity (B, seq_len, 1) # root_linear_velocity (B, seq_len, 2) # root_y (B, seq_len, 1) # ric_data (B, seq_len, (joint_num - 1)*3) , XYZ # rot_data (B, seq_len, (joint_num - 1)*6) , 6D # local_velocity (B, seq_len, joint_num*3) , XYZ # foot contact (B, seq_len, 4) , target_fc = target[:, -4:, :, :] root_rot_velocity = target[:, :1, :, :] root_linear_velocity = target[:, 1:3, :, :] root_y = target[:, 3:4, :, :] ric_data = target[:, 4:67, :, :] # 4+(3*21)=67 rot_data = target[:, 67:193, :, :] # 67+(6*21)=193 local_velocity = target[:, 193:259, :, :] # 193+(3*22)=259 contact = target[:, 259:, :, :] # 193+(3*22)=259 calc_vel_from_xyz = ric_data[:, :, :, 1:] - ric_data[:, :, :, :-1] velocity_from_vector = local_velocity[:, 3:, :, 1:] # Slicing out root r_rot_quat, r_pos = motion_process.recover_root_rot_pos(target.permute(0, 2, 3, 1).type(th.FloatTensor)) print(f'r_rot_quat: {r_rot_quat.shape}') print(f'calc_vel_from_xyz: {calc_vel_from_xyz.shape}') calc_vel_from_xyz = calc_vel_from_xyz.permute(0, 2, 3, 1) calc_vel_from_xyz = calc_vel_from_xyz.reshape((1, 1, -1, 21, 3)).type(th.FloatTensor) r_rot_quat_adapted = r_rot_quat[..., :-1, None, :].repeat((1,1,1,21,1)).to(calc_vel_from_xyz.device) print(f'calc_vel_from_xyz: {calc_vel_from_xyz.shape} , {calc_vel_from_xyz.device}') print(f'r_rot_quat_adapted: {r_rot_quat_adapted.shape}, {r_rot_quat_adapted.device}') calc_vel_from_xyz = motion_process.qrot(r_rot_quat_adapted, calc_vel_from_xyz) calc_vel_from_xyz = calc_vel_from_xyz.reshape((1, 1, -1, 21 * 3)) calc_vel_from_xyz = calc_vel_from_xyz.permute(0, 3, 1, 2) print(f'calc_vel_from_xyz: {calc_vel_from_xyz.shape} , {calc_vel_from_xyz.device}') import matplotlib.pyplot as plt for i in range(21): plt.plot(np.linalg.norm(calc_vel_from_xyz[:,i*3:(i+1)*3,:,:].cpu().detach().numpy().reshape((3, -1)), axis=0), label='Calc Vel') plt.plot(np.linalg.norm(velocity_from_vector[:,i*3:(i+1)*3,:,:].cpu().detach().numpy().reshape((3, -1)), axis=0), label='Vector Vel') plt.title(f'Joint idx: {i}') plt.legend() plt.show() print(calc_vel_from_xyz.shape) print(velocity_from_vector.shape) diff = calc_vel_from_xyz-velocity_from_vector print(np.linalg.norm(diff.cpu().detach().numpy().reshape((63, -1)), axis=0)) return 0 def _prior_bpd(self, x_start): """ Get the prior KL term for the variational lower-bound, measured in bits-per-dim. This term can't be optimized, as it only depends on the encoder. :param x_start: the [N x C x ...] tensor of inputs. :return: a batch of [N] KL values (in bits), one per batch element. """ batch_size = x_start.shape[0] t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) kl_prior = normal_kl( mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 ) return mean_flat(kl_prior) / np.log(2.0) def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): """ Compute the entire variational lower-bound, measured in bits-per-dim, as well as other related quantities. :param model: the model to evaluate loss on. :param x_start: the [N x C x ...] tensor of inputs. :param clip_denoised: if True, clip denoised samples. :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :return: a dict containing the following keys: - total_bpd: the total variational lower-bound, per batch element. - prior_bpd: the prior term in the lower-bound. 
- vb: an [N x T] tensor of terms in the lower-bound. - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. - mse: an [N x T] tensor of epsilon MSEs for each timestep. """ device = x_start.device batch_size = x_start.shape[0] vb = [] xstart_mse = [] mse = [] for t in list(range(self.num_timesteps))[::-1]: t_batch = th.tensor([t] * batch_size, device=device) noise = th.randn_like(x_start) x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) # Calculate VLB term at the current timestep with th.no_grad(): out = self._vb_terms_bpd( model, x_start=x_start, x_t=x_t, t=t_batch, clip_denoised=clip_denoised, model_kwargs=model_kwargs, ) vb.append(out["output"]) xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) mse.append(mean_flat((eps - noise) ** 2)) vb = th.stack(vb, dim=1) xstart_mse = th.stack(xstart_mse, dim=1) mse = th.stack(mse, dim=1) prior_bpd = self._prior_bpd(x_start) total_bpd = vb.sum(dim=1) + prior_bpd return { "total_bpd": total_bpd, "prior_bpd": prior_bpd, "vb": vb, "xstart_mse": xstart_mse, "mse": mse, } class GaussianDiffusionV4: """ Utilities for training and sampling diffusion models. Ported directly from here, and then adapted over time to further experimentation. https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 :param betas: a 1-D numpy array of betas for each diffusion timestep, ## pass floating point timesteps ## starting at T and going to 1. :param model_mean_type: a ModelMeanType determining what the model outputs. :param model_var_type: a ModelVarType determining how variance is output. :param loss_type: a LossType determining the loss function to use. :param rescale_timesteps: if True, pass floating point timesteps into the model so that they are always scaled like in the original paper (0 to 1000). """ def __init__( self, *, betas, model_mean_type, model_var_type, loss_type, rescale_timesteps=False, lambda_rcxyz=0., lambda_vel=0., lambda_pose=1., lambda_orient=1., lambda_loc=1., data_rep='rot6d', lambda_root_vel=0., lambda_vel_rcxyz=0., lambda_fc=0., denoising_stra="rep", inter_optim=False, args=None, ): self.model_mean_type = model_mean_type self.model_var_type = model_var_type ## model var type ## self.loss_type = loss_type self.rescale_timesteps = rescale_timesteps self.data_rep = data_rep self.args = args # possibly None ### GET the diff. suit ### self.diff_jts = self.args.diff_jts self.diff_basejtsrel = self.args.diff_basejtsrel self.diff_basejtse = self.args.diff_basejtse ### GET the diff. suit ### if data_rep != 'rot_vel' and lambda_pose != 1.: raise ValueError('lambda_pose is relevant only when training on velocities!') self.lambda_pose = lambda_pose self.lambda_orient = lambda_orient self.lambda_loc = lambda_loc self.lambda_rcxyz = lambda_rcxyz self.lambda_vel = lambda_vel self.lambda_root_vel = lambda_root_vel self.lambda_vel_rcxyz = lambda_vel_rcxyz self.lambda_fc = lambda_fc ### === denoising_stra for the denoising process === ### self.denoising_stra = denoising_stra self.inter_optim = inter_optim if self.lambda_rcxyz > 0. or self.lambda_vel > 0. or self.lambda_root_vel > 0. or \ self.lambda_vel_rcxyz > 0. or self.lambda_fc > 0.: assert self.loss_type == LossType.MSE, 'Geometric losses are supported by MSE loss type only!' # Use float64 for accuracy. 
betas = np.array(betas, dtype=np.float64) self.betas = betas assert len(betas.shape) == 1, "betas must be 1-D" ## betas assert (betas > 0).all() and (betas <= 1).all() self.num_timesteps = int(betas.shape[0]) alphas = 1.0 - betas self.alphas_cumprod = np.cumprod(alphas, axis=0) self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) ## alphas cumprod self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) # calculations for diffusion q(x_t | x_{t-1}) and others self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) # calculations for posterior q(x_{t-1} | x_t, x_0) self.posterior_variance = ( betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) ) # log calculation clipped because the posterior variance is 0 at the # beginning of the diffusion chain. self.posterior_log_variance_clipped = np.log( np.append(self.posterior_variance[1], self.posterior_variance[1:]) ) self.posterior_mean_coef1 = ( betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) ) self.posterior_mean_coef2 = ( (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) ) self.l2_loss = lambda a, b: (a - b) ** 2 # th.nn.MSELoss(reduction='none') # must be None for handling mask later on. def masked_l2(self, a, b, mask): # assuming a.shape == b.shape == bs, J, Jdim, seqlen # assuming mask.shape == bs, 1, 1, seqlen loss = self.l2_loss(a, b) loss = sum_flat(loss * mask.float()) # gives \sigma_euclidean over unmasked elements n_entries = a.shape[1] * a.shape[2] non_zero_elements = sum_flat(mask) * n_entries ## non-zero-elements # print('mask', mask.shape) # print('non_zero_elements', non_zero_elements) # print('loss', loss) mse_loss_val = loss / non_zero_elements # print('mse_loss_val', mse_loss_val) return mse_loss_val def q_mean_variance(self, x_start, t): """ Get the distribution q(x_t | x_0). # q-mean-variance # :param x_start: the [N x C x ...] tensor of noiseless inputs. :param t: the number of diffusion steps (minus 1). Here, 0 means one step. :return: A tuple (mean, variance, log_variance), all of x_start's shape. """ mean = ( _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start ) variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) log_variance = _extract_into_tensor( self.log_one_minus_alphas_cumprod, t, x_start.shape ) return mean, variance, log_variance def q_sample(self, x_start, t, noise=None): """ Diffuse the dataset for a given number of diffusion steps. In other words, sample from q(x_t | x_0). ## q pos :param x_start: the initial dataset batch. :param t: the number of diffusion steps (minus 1). Here, 0 means one step. :param noise: if specified, the split-out normal noise. :return: A noisy version of x_start. 
""" if noise is None: noise = th.randn_like(x_start) assert noise.shape == x_start.shape return ( _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise ) ## q_sample, def q_posterior_mean_variance(self, x_start, x_t, t): ## q_posterior """ Compute the mean and variance of the diffusion posterior: q(x_{t-1} | x_t, x_0) """ assert x_start.shape == x_t.shape posterior_mean = ( # posterior mean and variance # _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t ) posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) posterior_log_variance_clipped = _extract_into_tensor( self.posterior_log_variance_clipped, t, x_t.shape ) assert ( posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == x_start.shape[0] ) return posterior_mean, posterior_variance, posterior_log_variance_clipped def p_mean_variance( ## get mean data ## self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None ): """ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of the initial x, x_0. :param model: the model, which takes a signal and a batch of timesteps as input. :param x: the [N x C x ...] tensor at time t. :param t: a 1-D Tensor of timesteps. :param clip_denoised: if True, clip the denoised signal into [-1, 1]. # # clip the denoised signal ## :param denoised_fn: if not None, a function which applies to the x_start prediction before it is used to sample. Applies before clip_denoised. # denoised fn :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :return: a dict with the following keys: - 'mean': the model mean output. - 'variance': the model variance output. - 'log_variance': the log of 'variance'. - 'pred_xstart': the prediction for x_0. """ if model_kwargs is None: model_kwargs = {} ## === version 1 -> predict x_start at the rel domain === ## ### == x -> formulated as model_inputs == ### ### == TODO: process x_start from the predicted value (relative positions, signed distances) with base_pts and baes_normals == ### B = x['input_data']['base_pts'].shape[0] assert t.shape == (B,) input_data = x['input_data'] ## dec_out and out ## ## output dict ## out_dict = model.model.dec_latents_to_joints_with_t(x, input_data, self._scale_timesteps(t).clone()) rt_dict = {} # # }[self.model_var_type] # ### === model variance and log_variance === ### ## self.posterior_variance, self. 
model_variance, model_log_variance = self.posterior_variance, self.posterior_log_variance_clipped if self.diff_basejtse: base_jts_e_feats = x['base_jts_e_feats'] ### x_t values here ### pred_basejtse_seq_latents = out_dict['base_jts_e_feats'] ### q-sampled latent mean here ### basejtse_seq_latents_mean, _, _ = self.q_posterior_mean_variance( x_start=pred_basejtse_seq_latents.permute(1, 0, 2), x_t=base_jts_e_feats.permute(1, 0, 2), t=t ) basejtse_seq_latents_mean = basejtse_seq_latents_mean.permute(1, 0, 2) basejtse_seq_latents_variance = _extract_into_tensor(model_variance, t, basejtse_seq_latents_mean.permute(1, 0, 2).shape).permute(1, 0, 2) basejtse_seq_latents_log_variance = _extract_into_tensor(model_log_variance, t, basejtse_seq_latents_mean.permute(1, 0, 2).shape).permute(1, 0, 2) # base_jts_e_feats = out_dict["base_jts_e_feats"] dec_e_along_normals = out_dict["dec_e_along_normals"] dec_e_vt_normals = out_dict["dec_e_vt_normals"] dec_d = out_dict["dec_d"] rel_vel_dec = out_dict["rel_vel_dec"] basejtse_seq_rt_dict = { ### baesjtse seq latents ### "basejtse_seq_latents_mean": basejtse_seq_latents_mean, "basejtse_seq_latents_variance": basejtse_seq_latents_variance, "basejtse_seq_latents_log_variance": basejtse_seq_latents_log_variance, "dec_e_along_normals": dec_e_along_normals, "dec_e_vt_normals": dec_e_vt_normals, "dec_d": dec_d, "rel_vel_dec": rel_vel_dec } else: basejtse_seq_rt_dict = {} # rt_dict.update(jts_seq_rt_dict) # rt_dict.update(basejtsrel_seq_rt_dict) rt_dict.update(basejtse_seq_rt_dict) return rt_dict def _predict_xstart_from_eps(self, x_t, t, eps): assert x_t.shape == eps.shape return ( # extract into tensor # _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps ) def _predict_xstart_from_xprev(self, x_t, t, xprev): assert x_t.shape == xprev.shape return ( # (xprev - coef2*x_t) / coef1 _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev - _extract_into_tensor( self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape ) * x_t ) def _predict_eps_from_xstart(self, x_t, t, pred_xstart): return ( _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) def _scale_timesteps(self, t): if self.rescale_timesteps: return t.float() * (1000.0 / self.num_timesteps) return t def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): """ Compute the mean for the previous step, given a function cond_fn that computes the gradient of a conditional log probability with respect to x. In particular, cond_fn computes grad(log(p(y|x))), and we want to condition on y. This uses the conditioning strategy from Sohl-Dickstein et al. (2015). """ gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs) new_mean = ( p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() ) return new_mean def condition_mean_with_grad(self, cond_fn, p_mean_var, x, t, model_kwargs=None): """ Compute the mean for the previous step, given a function cond_fn that computes the gradient of a conditional log probability with respect to x. In particular, cond_fn computes grad(log(p(y|x))), and we want to condition on y. This uses the conditioning strategy from Sohl-Dickstein et al. (2015). 
""" gradient = cond_fn(x, t, p_mean_var, **model_kwargs) new_mean = ( p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() ) return new_mean def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): """ Compute what the p_mean_variance output would have been, should the model's score function be conditioned by cond_fn. See condition_mean() for details on cond_fn. Unlike condition_mean(), this instead uses the conditioning strategy from Song et al (2020). """ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) eps = eps - (1 - alpha_bar).sqrt() * cond_fn( x, self._scale_timesteps(t), **model_kwargs ) out = p_mean_var.copy() out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) out["mean"], _, _ = self.q_posterior_mean_variance( x_start=out["pred_xstart"], x_t=x, t=t ) return out def condition_score_with_grad(self, cond_fn, p_mean_var, x, t, model_kwargs=None): """ Compute what the p_mean_variance output would have been, should the model's score function be conditioned by cond_fn. See condition_mean() for details on cond_fn. Unlike condition_mean(), this instead uses the conditioning strategy from Song et al (2020). """ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) eps = eps - (1 - alpha_bar).sqrt() * cond_fn( x, t, p_mean_var, **model_kwargs ) out = p_mean_var.copy() out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) out["mean"], _, _ = self.q_posterior_mean_variance( x_start=out["pred_xstart"], x_t=x, t=t ) return out def p_sample( self, model, x, # psampele # t, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, const_noise=False, ): """ Sample x_{t-1} from the model at the given timestep. :param model: the model to sample from. :param x: the current tensor at x_{t-1}. :param t: the value of t, starting at 0 for the first diffusion step. :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. :param denoised_fn: if not None, a function which applies to the x_start prediction before it is used to sample. :param cond_fn: if not None, this is a gradient function that acts similarly to the model. :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :return: a dict containing the following keys: - 'sample': a random sample from the model. - 'pred_xstart': a prediction of x_0. 
""" out = self.p_mean_variance( model, x, t, # starting clip_denoised=clip_denoised, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) rt_dict = {} ### basejtsrel rt dict ### ### baesjtse seq latents ### # "basejtse_seq_latents_mean": basejtse_seq_latents_mean, # "basejtse_seq_latents_variance": basejtse_seq_latents_variance, # "basejtse_seq_latents_log_variance": basejtse_seq_latents_log_variance, # ### decoded output values ### # "joint_seq_output": joint_seq_output, # "basejtsrel_output": basejtsrel_output, # "dec_e_along_normals": dec_e_along_normals, # "dec_e_vt_normals": dec_e_vt_normals, if self.diff_basejtse: ##### ===== Sample for basejtse_seq_latents_sample ===== ##### ### rel_base_pts_outputs mask ### basejtse_seq_latents_noise = th.randn_like(x['base_jts_e_feats']) # print('const_noise', const_noise) if const_noise: basejtse_seq_latents_noise = basejtse_seq_latents_noise[[0]].repeat(x['base_jts_e_feats'].shape[0], 1, 1, 1, 1) basejtse_seq_latents_nonzero_mask = ( (t != 0).float().view(-1, *([1] * (len(x['base_jts_e_feats'].shape) - 1))) ) # no noise when t == 0 #### ==== basejtsrel_seq_latents ===== #### basejtse_seq_latents_sample = out["basejtse_seq_latents_mean"].permute(1, 0, 2) + basejtse_seq_latents_nonzero_mask * th.exp(0.5 * out["basejtse_seq_latents_log_variance"].permute(1, 0, 2)) * basejtse_seq_latents_noise.permute(1, 0, 2) basejtse_seq_latents_sample = basejtse_seq_latents_sample.permute(1, 0, 2) #### ==== basejtsrel_seq_latents ===== #### ##### ===== Sample for basejtse_seq_latents_sample ===== ##### dec_e_along_normals = out["dec_e_along_normals"] ## dec_e_vt_normals = out["dec_e_vt_normals"] dec_d = out['dec_d'] rel_vel_dec = out['rel_vel_dec'] basejtse_rt_dict = { "basejtse_seq_latents_sample": basejtse_seq_latents_sample, "dec_e_along_normals": dec_e_along_normals, "dec_e_vt_normals": dec_e_vt_normals, "dec_d": dec_d, "rel_vel_dec": rel_vel_dec } else: basejtse_rt_dict = {} rt_dict.update(jts_seq_rt_dict) rt_dict.update(basejtsrel_rt_dict) rt_dict.update(basejtse_rt_dict) return rt_dict def p_sample_with_grad( self, model, x, t, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, ): """ Sample x_{t-1} from the model at the given timestep. :param model: the model to sample from. :param x: the current tensor at x_{t-1}. :param t: the value of t, starting at 0 for the first diffusion step. :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. :param denoised_fn: if not None, a function which applies to the x_start prediction before it is used to sample. :param cond_fn: if not None, this is a gradient function that acts similarly to the model. :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :return: a dict containing the following keys: - 'sample': a random sample from the model. - 'pred_xstart': a prediction of x_0. 
""" with th.enable_grad(): x = x.detach().requires_grad_() out = self.p_mean_variance( model, x, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) noise = th.randn_like(x) nonzero_mask = ( (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) ) # no noise when t == 0 if cond_fn is not None: out["mean"] = self.condition_mean_with_grad( cond_fn, out, x, t, model_kwargs=model_kwargs ) sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise return {"sample": sample, "pred_xstart": out["pred_xstart"].detach()} def p_sample_loop( self, model, shape, noise=None, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, progress=False, skip_timesteps=0, init_image=None, randomize_class=False, cond_fn_with_grad=False, dump_steps=None, const_noise=False, st_timestep=None, ): ## """ Generate samples from the model. :param model: the model module. :param shape: the shape of the samples, (N, C, H, W). :param noise: if specified, the noise from the encoder to sample. Should be of the same shape as `shape`. :param clip_denoised: if True, clip x_start predictions to [-1, 1]. :param denoised_fn: if not None, a function which applies to the x_start prediction before it is used to sample. :param cond_fn: if not None, this is a gradient function that acts similarly to the model. :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :param device: if specified, the device to create the samples on. If not specified, use a model parameter's device. :param progress: if True, show a tqdm progress bar. :param const_noise: If True, will noise all samples with the same noise throughout sampling :return: a non-differentiable batch of samples. """ final = None # if dump_steps is not None: ## dump steps is not None ## dump = [] # function, yield, enumerate! -> for i, sample in enumerate(self.p_sample_loop_progressive( model, # p_sample # shape, # p_sample # noise=noise, clip_denoised=clip_denoised, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, device=device, progress=progress, skip_timesteps=skip_timesteps, init_image=init_image, randomize_class=randomize_class, cond_fn_with_grad=cond_fn_with_grad, const_noise=const_noise, # the same noise # st_timestep=st_timestep, )): if dump_steps is not None and i in dump_steps: dump.append(deepcopy(sample)) final = sample if dump_steps is not None: return dump return final def p_sample_loop_progressive( self, model, shape, noise=None, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, progress=False, skip_timesteps=0, init_image=None, randomize_class=False, cond_fn_with_grad=False, const_noise=False, st_timestep=None, ): # """ Generate samples from the model and yield intermediate samples from each timestep of diffusion. Arguments are the same as p_sample_loop(). Returns a generator over dicts, where each dict is the return value of p_sample(). """ ####### ==== a conditional ssampling from init_images here!!! 
==== ####### ## === give joints shape here === ## ### ==== set the shape for sampling ==== ### ### === init image should not be none === ### base_pts = init_image['base_pts'] base_normals = init_image['base_normals'] ## base normals ## base normals ## # rel_base_pts_to_rhand_joints = init_image['rel_base_pts_to_rhand_joints'] # dist_base_pts_to_rhand_joints = init_image['dist_base_pts_to_rhand_joints'] rhand_joints = init_image['rhand_joints'] vel_obj_pts_to_hand_pts = init_image['vel_obj_pts_to_hand_pts'] obj_pts_disp = init_image['obj_pts_disp'] avg_joints_sequence = torch.mean(rhand_joints, dim=1) # bsz x nnj x 3 ---> for rhand joints ## if 'sampled_base_pts_nearest_obj_pc' in init_image: ambient_init_image = { 'sampled_base_pts_nearest_obj_pc': init_image['sampled_base_pts_nearest_obj_pc'], 'sampled_base_pts_nearest_obj_vns': init_image['sampled_base_pts_nearest_obj_vns'], } # per_frame_avg_disp_along_normals, per_frame_std_disp_along_normals # per_frame_avg_disp_vt_normals, per_frame_std_disp_vt_normals if self.args.wo_e_normalization: init_image['per_frame_avg_disp_along_normals'] = torch.zeros_like(init_image['per_frame_avg_disp_along_normals']) init_image['per_frame_avg_disp_vt_normals'] = torch.zeros_like(init_image['per_frame_avg_disp_vt_normals']) init_image['per_frame_std_disp_along_normals'] = torch.ones_like(init_image['per_frame_std_disp_along_normals']) init_image['per_frame_std_disp_vt_normals'] = torch.ones_like(init_image['per_frame_std_disp_vt_normals']) if self.args.wo_rel_normalization: # per_frame_avg_joints_rel, per_frame_std_joints_rel # init_image['per_frame_avg_joints_rel'] = torch.zeros_like(init_image['per_frame_avg_joints_rel']) init_image['per_frame_std_joints_rel'] = torch.ones_like(init_image['per_frame_std_joints_rel']) init_image_avg_std_stats = { 'rhand_joints': init_image['rhand_joints'], 'per_frame_avg_joints_rel': init_image['per_frame_avg_joints_rel'], 'per_frame_std_joints_rel': init_image['per_frame_std_joints_rel'], 'per_frame_avg_joints_dists_rel': init_image['per_frame_avg_joints_dists_rel'], 'per_frame_std_joints_dists_rel': init_image['per_frame_std_joints_dists_rel'], } if device is None: device = next(model.parameters()).device assert isinstance(shape, (tuple, list)) # if noise is not None: # img = noise # else: # img = th.randn(*shape, device=device) ### sample progresssive ### # if skip_timesteps and init_image is None: # rhand_joints = th.zeros_like(img) # indicies indices = list(range(self.num_timesteps - skip_timesteps))[::-1] if st_timestep is not None: ## indices indices = indices[-st_timestep: ] print(f"st_timestep: {st_timestep}, indices: {indices}") joints_scaling_factor = 5. 
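        # --- Added sketch (illustrative only; `pts`, `lo`, `hi` are hypothetical names) ---
        # The "rep" branch below centers and rescales rhand_joints either by
        # per-sequence mean/std ("std") or by bounding-box center/extent
        # ("bbox"), then multiplies by joints_scaling_factor before noising.
        # The bbox variant, in miniature:
        #
        #   pts = rhand_joints.view(rhand_joints.size(0), -1, 3)  # bsz x (ws*nnj) x 3
        #   lo = pts.min(dim=1, keepdim=True)[0]
        #   hi = pts.max(dim=1, keepdim=True)[0]
        #   center, extent = (hi + lo) / 2., hi - lo              # bsz x 1 x 3 each
        #   normed = (pts - center) / extent                      # roughly a unit box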
# rel_base_pts_to_rhand_joints: bsz x ws x nnj x nnb x 3 ## # base_pts: bsz x nnb x 3 # # base_normals: bsz x nnb x 3 ## rel_base_pts_to_rhand_joints = rhand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1) # bsz x nnj x nnb x 3 # rhand_joints ## init_image['per_frame_avg_joints_rel'] = torch.mean(rel_base_pts_to_rhand_joints, dim=1, keepdim=True) # x_start['per_frame_avg_joints_rel'] = torch # bsz x ws x nnj x nnb x 3 # rel_base_pts_to_rhand_joints = (rel_base_pts_to_rhand_joints - init_image['per_frame_avg_joints_rel']) / init_image['per_frame_std_joints_rel'] if self.denoising_stra == "rep": ''' Normalization Strategy 4 ''' my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0] # t con # normed_rhand_joints = (rhand_joints - self.avg_jts.unsqueeze(0).to(rhand_joints.device)) / self.std_jts.unsqueeze(0).to(rhand_joints.device) # noise_rhand_joints = th.randn_like(normed_rhand_joints) # pert_normed_rhand_joints = self.q_sample(normed_rhand_joints, my_t, noise=noise_rhand_joints) # # ### scale rhand joints ## # # rhand joints: bsz x ws x nnj x 3 exp_rhand_joints = rhand_joints.view(rhand_joints.size(0), rhand_joints.size(1) * rhand_joints.size(2), 3) # ## avg exp rhadn joints ## if self.args.jts_sclae_stra == "std": avg_exp_rhand_joints = torch.mean(exp_rhand_joints, dim=1, keepdim=True) extents_rhand_joints = torch.std(exp_rhand_joints, dim=1, keepdim=True) elif self.args.jts_sclae_stra == "bbox": maxx_exp_rhand_joints, _ = torch.max(exp_rhand_joints, dim=1, keepdim=True) minn_exp_rhand_joints, _ = torch.min(exp_rhand_joints, dim=1, keepdim=True) avg_exp_rhand_joints = (maxx_exp_rhand_joints + minn_exp_rhand_joints) / 2. # avg_exp_rhand_joints extents_rhand_joints = maxx_exp_rhand_joints - minn_exp_rhand_joints ### bsz x 1 x 3 # extents_rhand_joints = torch.sqrt(torch.sum(extents_rhand_joints ** 2, dim=-1, keepdim=True)) else: raise ValueError(f"Unrecognized jts_sclae_stra: {self.args.jts_sclae_stra}") rhand_joints = (rhand_joints - avg_exp_rhand_joints.unsqueeze(1) ) / extents_rhand_joints.unsqueeze(1) scaled_rhand_joints = rhand_joints * joints_scaling_factor noise_scaled_rhand_joints = th.randn_like(scaled_rhand_joints) pert_scaled_rhand_joints = self.q_sample(scaled_rhand_joints, my_t, noise=noise_scaled_rhand_joints) # pert_rhand_joints: bsz x nnj x 3 ## -> # pert_rhand_joints = pert_normed_rhand_joints * self.std_jts.unsqueeze(0).to(rhand_joints.device) + self.avg_jts.unsqueeze(0).to(rhand_joints.device) ### Calculate moving related energies ### # denormed_rhand_joints: bsz x nf x nnj x 3 -> denormed rhand joints here # # pert rhand joints # ## denormed base pts to rhand joints -> denormed rel positions ## denormed_rel_base_pts_to_rhand_joints = rel_base_pts_to_rhand_joints * init_image['per_frame_std_joints_rel'] + init_image['per_frame_avg_joints_rel'] denormed_rhand_joints = rhand_joints * extents_rhand_joints.unsqueeze(1) + avg_exp_rhand_joints.unsqueeze(1) denormed_dist_rel_base_pts_to_pert_rhand_joints = torch.sum( denormed_rel_base_pts_to_rhand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1 ) ## l2 real base pts k_f = 1. 
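            # --- Added note ---
            # k_f sets how quickly the attraction weight decays with the
            # base-point-to-joint distance: att = exp(-k_f * d), so contact
            # (d ~ 0) gets weight ~1 and far pairs are exponentially
            # down-weighted in the energy terms below. With k_f = 1:
            #   torch.exp(-torch.tensor([0.0, 0.5, 2.0]))
            #   # -> tensor([1.0000, 0.6065, 0.1353])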
            ## l2 rel base pts to pert rhand joints ##
            # l2_rel_base_pts_to_pert_rhand_joints: bsz x nf x nnj x nnb
            l2_rel_base_pts_to_pert_rhand_joints = torch.norm(denormed_rel_base_pts_to_rhand_joints, dim=-1)
            ### att_forces ##
            att_forces = torch.exp(-k_f * l2_rel_base_pts_to_pert_rhand_joints)  # bsz x nf x nnj x nnb
            # bsz x (ws - 1) x nnj x nnb #
            att_forces = att_forces[:, :-1, :, :]  # attraction forces, dropping the last frame
            # rhand_joints: ws x nnj x 3 -> (ws - 1) x nnj x 3 ## rhand_joints ##
            # bsz x (ws - 1) x nnj x 3 --> displacements
            denormed_rhand_joints_disp = denormed_rhand_joints[:, 1:, :, :] - denormed_rhand_joints[:, :-1, :, :]
            # distance along base_normals: (ws - 1) x nnj x nnb x 3 --> bsz x (ws - 1) x nnj x nnb #
            # signed_dist_base_pts_to_pert_rhand_joints_along_normal: bsz x (ws - 1) x nnj x nnb #
            signed_dist_base_pts_to_rhand_joints_along_normal = torch.sum(
                base_normals.unsqueeze(1).unsqueeze(1) * denormed_rhand_joints_disp.unsqueeze(-2), dim=-1
            )
            # rel_base_pts_to_pert_rhand_joints_vt_normal: bsz x (ws - 1) x nnj x nnb x 3 -> the relative positions vertical to base normals #
            rel_base_pts_to_rhand_joints_vt_normal = denormed_rhand_joints_disp.unsqueeze(-2) - signed_dist_base_pts_to_rhand_joints_along_normal.unsqueeze(-1) * base_normals.unsqueeze(1).unsqueeze(1)
            dist_base_pts_to_rhand_joints_vt_normal = torch.sqrt(torch.sum(
                rel_base_pts_to_rhand_joints_vt_normal ** 2, dim=-1
            ))
            k_a = 1.
            k_b = 1.
            ###
            e_disp_rel_to_base_along_normals = k_a * att_forces * torch.abs(signed_dist_base_pts_to_rhand_joints_along_normal)
            # (ws - 1) x nnj x nnb # -> dist vt normals ##
            e_disp_rel_to_baes_vt_normals = k_b * att_forces * dist_base_pts_to_rhand_joints_vt_normal
            # nf x nnj x nnb ---> dist_vt_normals -> nf x nnj x nnb # # torch.sqrt() ##
            # dist_base_pts_to_pert_rhand_joints, rel_base_pts_to_pert_rhand_joints #
            # e_disp_rel_to_base_along_normals, e_disp_rel_to_baes_vt_normals #
            # print("per frame avg along normals,", x_start['per_frame_avg_disp_along_normals'].size(), "per frame std along normals", x_start['per_frame_std_disp_along_normals'].size(), "e_disp_rel_to_base_along_normals", e_disp_rel_to_base_along_normals.size(), "e_disp_rel_to_baes_vt_normals", e_disp_rel_to_baes_vt_normals.size())
            e_disp_rel_to_base_along_normals = (e_disp_rel_to_base_along_normals - init_image['per_frame_avg_disp_along_normals']) / init_image['per_frame_std_disp_along_normals']
            e_disp_rel_to_baes_vt_normals = (e_disp_rel_to_baes_vt_normals - init_image['per_frame_avg_disp_vt_normals']) / init_image['per_frame_std_disp_vt_normals']
            ''' Normalization Strategy 4 '''
        elif self.denoising_stra == "motion_to_rep":
            my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0]
            joints_noise = th.randn_like(rhand_joints)
            pert_rhand_joints = self.q_sample(rhand_joints, my_t, noise=joints_noise)
            pert_rhand_joints_denorm = pert_rhand_joints * self.std_jts.unsqueeze(0).to(rhand_joints.device) + self.avg_jts.unsqueeze(0).to(rhand_joints.device)
            # bsz x nf x nnj x nnb x 3
            rel_base_pts_to_pert_rhand_joints = pert_rhand_joints_denorm.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1)
            dist_base_pts_to_pert_rhand_joints = torch.sum(
                rel_base_pts_to_pert_rhand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1
            )
        else:
            raise ValueError(f"Unrecognized denoising_stra: {self.denoising_stra}")

        # maxx_exp_rhand_joints, _ = torch.max(exp_rhand_joints, dim=1, keepdim=True)
        # minn_exp_rhand_joints, _ = torch.min(exp_rhand_joints, dim=1, keepdim=True)
        # avg_exp_rhand_joints = (maxx_exp_rhand_joints + minn_exp_rhand_joints) / 2.
        # rhand_joints = rhand_joints - avg_exp_rhand_joints.unsqueeze(1)
        # ### scale rhand joints ## # rhand joints: bsz x ws x nnj x 3 ## rhand joints for exp_rhand_joints ##
        # exp_rhand_joints = rhand_joints.view(rhand_joints.size(0), rhand_joints.size(1) * rhand_joints.size(2), 3)
        # maxx_exp_rhand_joints, _ = torch.max(exp_rhand_joints, dim=1, keepdim=True)
        # minn_exp_rhand_joints, _ = torch.min(exp_rhand_joints, dim=1, keepdim=True)
        # avg_exp_rhand_joints = (maxx_exp_rhand_joints + minn_exp_rhand_joints) / 2.
        # extents_rhand_joints = maxx_exp_rhand_joints - minn_exp_rhand_joints  ### bsz x 1 x 3 #
        # extents_rhand_joints = torch.sqrt(torch.sum(extents_rhand_joints ** 2, dim=-1, keepdim=True))
        # rhand_joints = (rhand_joints - avg_exp_rhand_joints.unsqueeze(1)) / extents_rhand_joints.unsqueeze(1)

        # denoised representations --- denoising -->
        input_data = {
            'base_pts': base_pts,
            'base_normals': base_normals,
            # 'rel_base_pts_to_rhand_joints': rel_base_pts_to_pert_rhand_joints,
            # 'dist_base_pts_to_rhand_joints': dist_base_pts_to_pert_rhand_joints,
            # 'pert_rhand_joints': pert_normed_rhand_joints,
            'pert_rhand_joints': pert_scaled_rhand_joints,
            'rhand_joints': rhand_joints,
            # 'rel_base_pts_to_rhand_joints': rel_base_pts_to_rhand_joints.clone(),
            'avg_joints_sequence': avg_joints_sequence,
        }
        if 'sampled_base_pts_nearest_obj_pc' in init_image:
            input_data.update(ambient_init_image)
        input_data.update(
            {
                'e_disp_rel_to_base_along_normals': e_disp_rel_to_base_along_normals,
                'e_disp_rel_to_baes_vt_normals': e_disp_rel_to_baes_vt_normals,
                'vel_obj_pts_to_hand_pts': vel_obj_pts_to_hand_pts,
                'obj_pts_disp': obj_pts_disp
            }
        )  # input
        input_data.update(init_image_avg_std_stats)
        input_data['rhand_joints'] = rhand_joints  # normed

        my_t = th.tensor([indices[-1]] * shape[0], device=device)
        # clean_joint_seq_latents = model(input_data, self._scale_timesteps(my_t))
        # noise_joint_seq_latents = th.randn_like(clean_joint_seq_latents)
        # # pert_joint_seq_latents: bsz x seq x d #
        # pert_joint_seq_latents = self.q_sample(clean_joint_seq_latents.permute(1, 0, 2).contiguous(), my_t, noise=noise_joint_seq_latents.permute(1, 0, 2).contiguous()).permute(1, 0, 2).contiguous()
        # clean_joint_seq_latents: seq x bs x d #
        # clean_joint_seq_latents = model(input_data, self._scale_timesteps(t).clone())
        ## pert_joint_seq_latents, pert_basejtsrel_seq_latents ##
        out_dict = model(input_data, self._scale_timesteps(my_t).clone())

        dec_in_dict = {}
        if self.diff_basejtse:
            ### Sample for perturbed basejtse seq latents ###
            basejtse_seq_latents = out_dict["base_jts_e_feats"]
            # if 'base_jts_e_feats_mean' in out_dict: basejtse_seq_latents = out_dict['base_jts_e_feats_mean']
            noise_basejtse_seq_latents = th.randn_like(basejtse_seq_latents)
            pert_basejtse_seq_latents = self.q_sample(basejtse_seq_latents.permute(1, 0, 2).contiguous(), my_t, noise=noise_basejtse_seq_latents.permute(1, 0, 2).contiguous()).permute(1, 0, 2).contiguous()
            if self.args.rnd_noise:
                dec_in_dict['base_jts_e_feats'] = basejtse_seq_latents
            else:
                dec_in_dict['base_jts_e_feats'] = pert_basejtse_seq_latents
            # NOTE: the original code unconditionally overwrote the choice above
            # with the perturbed latents, which silently made `args.rnd_noise` a
            # no-op; the overwrite is dropped so the flag actually selects
            # between clean and perturbed latents.
            dec_in_dict['base_jts_e_feats_enc'] = basejtse_seq_latents
        # dec in dict here #
        # dec_in_dict = {
        #     "joints_seq_latents": pert_joint_seq_latents,
        #     "rel_base_pts_outputs": pert_basejtsrel_seq_latents,
        #     "base_jts_e_feats": pert_basejtse_seq_latents,
        # }
        ### !!! update for input data !!!
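        # --- Added summary of the sampling flow from here on ---
        # 1) the clean inputs are encoded once and their latents q_sample'd to
        #    the starting timestep (above);
        # 2) the loop below walks `indices` from high t to low t, calling
        #    p_sample on the latent dict at each step;
        # 3) after each step the decoded energies are denormalized with the
        #    per-frame avg/std stats and repacked into input_data for the next
        #    iteration, which is what this generator yields.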
### dec_in_dict['input_data'] = input_data # input_data['pert_joint_seq_latents'] = pert_joint_seq_latents ## decoded model_kwargs = { k: val for k, val in init_image.items() if k not in input_data } if progress: # Lazy import so that we don't depend on tqdm. from tqdm.auto import tqdm indices = tqdm(indices) for i_idx, i in enumerate(indices): t = th.tensor([i] * shape[0], device=device) if randomize_class and 'y' in model_kwargs: model_kwargs['y'] = th.randint(low=0, high=model.num_classes, size=model_kwargs['y'].shape, # size of y device=model_kwargs['y'].device) # device of y with th.no_grad(): # inter_optim # # p_sample_with_grad sample_fn = self.p_sample_with_grad if cond_fn_with_grad else self.p_sample # out = sample_fn( model, dec_in_dict, ## sample from input data ## t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, const_noise=const_noise, ) # yield out # img = out["sample"] # dec_clean_joint_seq = out["dec_clean_joint_seq"] input_data = {} dec_in_dict = {} # dec_clean_joint_seq = dec_clean_joint_seq * extents_rhand_joints.unsqueeze(1) + avg_exp_rhand_joints.unsqueeze(1) # latent spapce -> # single step noise ## gaussian diffusion ours ## # basejtsrel_output: bsz x nf x nnj x nnb x 3 --> rel outputs # # diff_basejtse ## basejtse ## if self.diff_basejtse: ## seq latents ## seq latents ## basejtse_seq_latents_sample = out["basejtse_seq_latents_sample"] dec_e_along_normals = out["dec_e_along_normals"] dec_e_vt_normals = out["dec_e_vt_normals"] dec_d = out["dec_d"] rel_vel_dec = out["rel_vel_dec"] dec_e_along_normals = dec_e_along_normals * init_image['per_frame_std_disp_along_normals'] + init_image['per_frame_avg_disp_along_normals'] dec_e_vt_normals = dec_e_vt_normals * init_image['per_frame_std_disp_vt_normals'] + init_image['per_frame_avg_disp_vt_normals'] ## model constraints and model impacts from object a to object c ## basejtse_seq_input_dict = { 'e_disp_rel_to_base_along_normals': dec_e_along_normals, 'e_disp_rel_to_baes_vt_normals': dec_e_vt_normals, ### vt_normals ### 'rel_vel_dec': rel_vel_dec, 'dec_d': dec_d } # basejts e seq dec in dict # basejtse_seq_dec_in_dict = { "base_jts_e_feats": basejtse_seq_latents_sample, } else: basejtse_seq_input_dict = {} basejtse_seq_dec_in_dict = {} input_data = { 'base_pts': base_pts, 'base_normals': base_normals, 'rhand_joints': rhand_joints, } ## jts seq input_data.update(jts_seq_input_dict) input_data.update(basejtsrel_seq_input_dict) input_data.update(basejtse_seq_input_dict) if 'sampled_base_pts_nearest_obj_pc' in init_image: input_data.update(ambient_init_image) input_data.update(init_image_avg_std_stats) input_data['rhand_joints']= rhand_joints ## input_data ## ### sampled rhand joints ### if 'sampled_rhand_joints' not in input_data: # sampled_rhand_joints --> sampled rhand joints sampled_rhand_joints = rhand_joints * extents_rhand_joints.unsqueeze(1) + avg_exp_rhand_joints.unsqueeze(1) # latent spapce -> # single step noise # input_data['sampled_rhand_joints'] = sampled_rhand_joints # sampled_rhand_joints # # sampled_rhand_joints --> sampled_rhand_joints # dec_in_dict.update(jts_seq_dec_in_dict) dec_in_dict.update(basejtse_seq_dec_in_dict) dec_in_dict.update(basejtsrel_seq_dec_in_dict) dec_in_dict['input_data'] = input_data # dec_in_dict = { # "input_data": input_data, # } yield input_data def ddim_sample( self, model, x, t, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, eta=0.0, ): """ Sample x_{t-1} from the model using DDIM. Same usage as p_sample(). 
""" out_orig = self.p_mean_variance( model, x, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) if cond_fn is not None: out = self.condition_score(cond_fn, out_orig, x, t, model_kwargs=model_kwargs) else: out = out_orig # Usually our model outputs epsilon, but we re-derive it # in case we used x_start or x_prev prediction. eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) sigma = ( eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * th.sqrt(1 - alpha_bar / alpha_bar_prev) ) # Equation 12. ##; hand position and relative noise = th.randn_like(x) mean_pred = ( out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps ) nonzero_mask = ( (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) ) # no noise when t == 0 sample = mean_pred + nonzero_mask * sigma * noise return {"sample": sample, "pred_xstart": out_orig["pred_xstart"]} def ddim_sample_with_grad( self, model, x, t, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, eta=0.0, ): """ Sample x_{t-1} from the model using DDIM. Same usage as p_sample(). """ with th.enable_grad(): x = x.detach().requires_grad_() out_orig = self.p_mean_variance( model, x, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) if cond_fn is not None: out = self.condition_score_with_grad(cond_fn, out_orig, x, t, model_kwargs=model_kwargs) else: out = out_orig out["pred_xstart"] = out["pred_xstart"].detach() # Usually our model outputs epsilon, but we re-derive it # in case we used x_start or x_prev prediction. eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) sigma = ( eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * th.sqrt(1 - alpha_bar / alpha_bar_prev) ) # Equation 12. noise = th.randn_like(x) mean_pred = ( out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps ) nonzero_mask = ( (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) ) # no noise when t == 0 sample = mean_pred + nonzero_mask * sigma * noise return {"sample": sample, "pred_xstart": out_orig["pred_xstart"].detach()} ## def ddim_reverse_sample( self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None, eta=0.0, ): """ Sample x_{t+1} from the model using DDIM reverse ODE. """ assert eta == 0.0, "Reverse ODE only for deterministic path" out = self.p_mean_variance( model, x, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) # Usually our model outputs epsilon, but we re-derive it # in case we used x_start or x_prev prediction. eps = ( _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - out["pred_xstart"] ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) # Equation 12. 
reversed mean_pred = ( out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps ) return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} def ddim_sample_loop( self, model, shape, noise=None, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, progress=False, eta=0.0, skip_timesteps=0, init_image=None, randomize_class=False, cond_fn_with_grad=False, dump_steps=None, const_noise=False, ): """ Generate samples from the model using DDIM. Same usage as p_sample_loop(). """ if dump_steps is not None: raise NotImplementedError() if const_noise == True: raise NotImplementedError() final = None for sample in self.ddim_sample_loop_progressive( model, shape, noise=noise, clip_denoised=clip_denoised, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, device=device, progress=progress, eta=eta, skip_timesteps=skip_timesteps, init_image=init_image, randomize_class=randomize_class, cond_fn_with_grad=cond_fn_with_grad, ): final = sample return final["sample"] def ddim_sample_loop_progressive( self, model, shape, noise=None, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, progress=False, eta=0.0, skip_timesteps=0, init_image=None, randomize_class=False, cond_fn_with_grad=False, ): """ Use DDIM to sample from the model and yield intermediate samples from each timestep of DDIM. Same usage as p_sample_loop_progressive(). """ if device is None: device = next(model.parameters()).device assert isinstance(shape, (tuple, list)) if noise is not None: img = noise else: img = th.randn(*shape, device=device) if skip_timesteps and init_image is None: init_image = th.zeros_like(img) indices = list(range(self.num_timesteps - skip_timesteps))[::-1] if init_image is not None: my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0] img = self.q_sample(init_image, my_t, img) if progress: # Lazy import so that we don't depend on tqdm. from tqdm.auto import tqdm indices = tqdm(indices) for i in indices: t = th.tensor([i] * shape[0], device=device) if randomize_class and 'y' in model_kwargs: model_kwargs['y'] = th.randint(low=0, high=model.num_classes, size=model_kwargs['y'].shape, device=model_kwargs['y'].device) with th.no_grad(): sample_fn = self.ddim_sample_with_grad if cond_fn_with_grad else self.ddim_sample out = sample_fn( model, img, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, eta=eta, ) yield out img = out["sample"] def plms_sample( self, model, x, t, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, cond_fn_with_grad=False, order=2, old_out=None, ): """ Sample x_{t-1} from the model using Pseudo Linear Multistep. Same usage as p_sample(). """ if not int(order) or not 1 <= order <= 4: raise ValueError('order is invalid (should be int from 1-4).') def get_model_output(x, t): with th.set_grad_enabled(cond_fn_with_grad and cond_fn is not None): x = x.detach().requires_grad_() if cond_fn_with_grad else x out_orig = self.p_mean_variance( model, x, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) if cond_fn is not None: if cond_fn_with_grad: out = self.condition_score_with_grad(cond_fn, out_orig, x, t, model_kwargs=model_kwargs) x = x.detach() else: out = self.condition_score(cond_fn, out_orig, x, t, model_kwargs=model_kwargs) else: out = out_orig # Usually our model outputs epsilon, but we re-derive it # in case we used x_start or x_prev prediction. 
eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) return eps, out, out_orig alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) eps, out, out_orig = get_model_output(x, t) if order > 1 and old_out is None: # Pseudo Improved Euler old_eps = [eps] mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev) * eps eps_2, _, _ = get_model_output(mean_pred, t - 1) eps_prime = (eps + eps_2) / 2 pred_prime = self._predict_xstart_from_eps(x, t, eps_prime) mean_pred = pred_prime * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev) * eps_prime else: # Pseudo Linear Multistep (Adams-Bashforth) old_eps = old_out["old_eps"] old_eps.append(eps) cur_order = min(order, len(old_eps)) if cur_order == 1: eps_prime = old_eps[-1] elif cur_order == 2: eps_prime = (3 * old_eps[-1] - old_eps[-2]) / 2 elif cur_order == 3: eps_prime = (23 * old_eps[-1] - 16 * old_eps[-2] + 5 * old_eps[-3]) / 12 elif cur_order == 4: eps_prime = (55 * old_eps[-1] - 59 * old_eps[-2] + 37 * old_eps[-3] - 9 * old_eps[-4]) / 24 else: raise RuntimeError('cur_order is invalid.') pred_prime = self._predict_xstart_from_eps(x, t, eps_prime) mean_pred = pred_prime * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev) * eps_prime if len(old_eps) >= order: old_eps.pop(0) nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) sample = mean_pred * nonzero_mask + out["pred_xstart"] * (1 - nonzero_mask) return {"sample": sample, "pred_xstart": out_orig["pred_xstart"], "old_eps": old_eps} def plms_sample_loop( self, model, shape, noise=None, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, progress=False, skip_timesteps=0, init_image=None, randomize_class=False, cond_fn_with_grad=False, order=2, ): """ Generate samples from the model using Pseudo Linear Multistep. Same usage as p_sample_loop(). """ final = None for sample in self.plms_sample_loop_progressive( model, shape, noise=noise, clip_denoised=clip_denoised, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, device=device, progress=progress, skip_timesteps=skip_timesteps, init_image=init_image, randomize_class=randomize_class, cond_fn_with_grad=cond_fn_with_grad, order=order, ): final = sample return final["sample"] def plms_sample_loop_progressive( self, model, shape, noise=None, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, progress=False, skip_timesteps=0, init_image=None, randomize_class=False, cond_fn_with_grad=False, order=2, ): """ Use PLMS to sample from the model and yield intermediate samples from each timestep of PLMS. Same usage as p_sample_loop_progressive(). """ if device is None: device = next(model.parameters()).device assert isinstance(shape, (tuple, list)) if noise is not None: img = noise else: img = th.randn(*shape, device=device) if skip_timesteps and init_image is None: init_image = th.zeros_like(img) indices = list(range(self.num_timesteps - skip_timesteps))[::-1] if init_image is not None: my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0] img = self.q_sample(init_image, my_t, img) if progress: # Lazy import so that we don't depend on tqdm. 
from tqdm.auto import tqdm indices = tqdm(indices) old_out = None for i in indices: t = th.tensor([i] * shape[0], device=device) if randomize_class and 'y' in model_kwargs: model_kwargs['y'] = th.randint(low=0, high=model.num_classes, size=model_kwargs['y'].shape, device=model_kwargs['y'].device) with th.no_grad(): out = self.plms_sample( model, img, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, cond_fn_with_grad=cond_fn_with_grad, order=order, old_out=old_out, ) yield out old_out = out img = out["sample"] def _vb_terms_bpd( self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None ): """ Get a term for the variational lower-bound. The resulting units are bits (rather than nats, as one might expect). This allows for comparison to other papers. :return: a dict with the following keys: - 'output': a shape [N] tensor of NLLs or KLs. - 'pred_xstart': the x_0 predictions. """ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( x_start=x_start, x_t=x_t, t=t ) out = self.p_mean_variance( model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs ) kl = normal_kl( true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] ) kl = mean_flat(kl) / np.log(2.0) decoder_nll = -discretized_gaussian_log_likelihood( x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] ) assert decoder_nll.shape == x_start.shape decoder_nll = mean_flat(decoder_nll) / np.log(2.0) # At the first timestep return the decoder NLL, # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) output = th.where((t == 0), decoder_nll, kl) return {"output": output, "pred_xstart": out["pred_xstart"]} ## ## training losses ## ## training losses ## def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, dataset=None): """ # training losses # training losses for rel/dist representations ## Compute training losses for a single timestep. :param model: the model to evaluate loss on. :param x_start: the [N x C x ...] tensor of inputs. :param t: a batch of timestep indices. :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :param noise: if specified, the specific Gaussian noise to try to remove. :return: a dict with the key "loss" containing a tensor of shape [N]. Some mean or variance settings may also have other keys. 
""" if self.args.train_diff: # set enc to evals # # print(f"Setitng encoders to eval mode") model.model.set_enc_to_eval() enc = model.model ## model.model mask = model_kwargs['y']['mask'] ## rot2xyz get_xyz = lambda sample: enc.rot2xyz(sample, mask=None, pose_rep=enc.pose_rep, translation=enc.translation, glob=enc.glob, ## rot2xyz; ## # jointstype='vertices', # 3.4 iter/sec # USED ALSO IN MotionCLIP jointstype='smpl', # 3.4 iter/sec vertstrans=False) # bsz x ws x nnj x 3 # # base_pts: bsz x nnb x 3 # # base_normals: bsz x nnb x 3 # # bsz x ws x nnjts x 3 # rhand_joints = x_start['rhand_joints'] # bsz x nnbase x 3 # base_pts = x_start['base_pts'] # bsz x ws x nnbase x 3 # base_normals = x_start['base_normals'] vel_obj_pts_to_hand_pts = x_start["vel_obj_pts_to_hand_pts"] obj_pts_disp = x_start["obj_pts_disp"] dist_base_pts_to_rhand_joints = x_start['dist_base_pts_to_rhand_joints'] avg_joints_sequence = torch.mean(rhand_joints, dim=1) # bsz x nnj x 3 ---> for rhand joints ## # # bsz x ws x nnjts x nnbase x 3 # # rel_base_pts_to_rhand_joints = x_start['rel_base_pts_to_rhand_joints'] # # bsz x ws x nnjts x nnbase # # dist_base_pts_to_rhand_joints = x_start['dist_base_pts_to_rhand_joints'] ## no diffu # normalization strategy for joints and that for the representation values # if 'sampled_base_pts_nearest_obj_pc' in x_start: ambient_xstart_dict = { 'sampled_base_pts_nearest_obj_pc': x_start['sampled_base_pts_nearest_obj_pc'], 'sampled_base_pts_nearest_obj_vns': x_start['sampled_base_pts_nearest_obj_vns'], } # per_frame_avg_disp_along_normals, per_frame_std_disp_along_normals # per_frame_avg_disp_vt_normals, per_frame_std_disp_vt_normals if self.args.wo_e_normalization: x_start['per_frame_avg_disp_along_normals'] = torch.zeros_like(x_start['per_frame_avg_disp_along_normals']) x_start['per_frame_avg_disp_vt_normals'] = torch.zeros_like(x_start['per_frame_avg_disp_vt_normals']) x_start['per_frame_std_disp_along_normals'] = torch.ones_like(x_start['per_frame_std_disp_along_normals']) x_start['per_frame_std_disp_vt_normals'] = torch.ones_like(x_start['per_frame_std_disp_vt_normals']) if self.args.wo_rel_normalization: # per_frame_avg_joints_rel, per_frame_std_joints_rel # x_start['per_frame_avg_joints_rel'] = torch.zeros_like(x_start['per_frame_avg_joints_rel']) x_start['per_frame_std_joints_rel'] = torch.ones_like(x_start['per_frame_std_joints_rel']) ## rel_base_pts_to_rhand_joints: bsz x ws x nnj x nnb x 3 ## # base_pts: bsz x nnb x 3 # # base_normals: bsz x nnb x 3 ## rel_base_pts_to_rhand_joints = rhand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1) # bsz x nnj x nnb x 3 # rhand_joints ## x_start['per_frame_avg_joints_rel'] = torch.mean(rel_base_pts_to_rhand_joints, dim=1, keepdim=True) # x_start['per_frame_avg_joints_rel'] = torch # bsz x ws x nnj x nnb x 3 # rel_base_pts_to_rhand_joints = (rel_base_pts_to_rhand_joints - x_start['per_frame_avg_joints_rel']) / x_start['per_frame_std_joints_rel'] # data normalization; # construct statistics, normalize values # joints_scaling_factor = 5. 
        ''' GET rel and dists '''
        if self.denoising_stra == "rep":
            # rel_base_pts_to_rhand_joints: bsz x ws x nnj x nnb x 3 #
            # avg_jts: 1 x nnj x 3; std_jts: 1 x nnj x 3 #
            # rhand_joints: bsz x ws x nnj x 3; normalize rhand joints #
            normed_rhand_joints = (rhand_joints - self.avg_jts.unsqueeze(0).to(rhand_joints.device)) / self.std_jts.unsqueeze(0).to(rhand_joints.device)
            noise_rhand_joints = th.randn_like(normed_rhand_joints)
            pert_normed_rhand_joints = self.q_sample(normed_rhand_joints, t, noise=noise_rhand_joints)

            ### scale rhand joints ###
            # rhand joints: bsz x ws x nnj x 3 ## each joint 1 x 3 -> normalization #
            exp_rhand_joints = rhand_joints.view(rhand_joints.size(0), rhand_joints.size(1) * rhand_joints.size(2), 3)
            # note: 'jts_sclae_stra' is the (typo'd) argument name used throughout; kept for compatibility #
            if self.args.jts_sclae_stra == "std":
                avg_exp_rhand_joints = torch.mean(exp_rhand_joints, dim=1, keepdim=True)
                extents_rhand_joints = torch.std(exp_rhand_joints, dim=1, keepdim=True)
            elif self.args.jts_sclae_stra == "bbox":  ### bounding box ###
                maxx_exp_rhand_joints, _ = torch.max(exp_rhand_joints, dim=1, keepdim=True)
                minn_exp_rhand_joints, _ = torch.min(exp_rhand_joints, dim=1, keepdim=True)
                avg_exp_rhand_joints = (maxx_exp_rhand_joints + minn_exp_rhand_joints) / 2.  # bbox center #
                extents_rhand_joints = maxx_exp_rhand_joints - minn_exp_rhand_joints  ### bsz x 1 x 3 ###
                extents_rhand_joints = torch.sqrt(torch.sum(extents_rhand_joints ** 2, dim=-1, keepdim=True))  # bbox diagonal length #
            else:
                raise ValueError(f"Unrecognized jts_sclae_stra: {self.args.jts_sclae_stra}")
            rhand_joints = (rhand_joints - avg_exp_rhand_joints.unsqueeze(1)) / extents_rhand_joints.unsqueeze(1)
            scaled_rhand_joints = rhand_joints * joints_scaling_factor
            noise_scaled_rhand_joints = th.randn_like(scaled_rhand_joints)
            pert_scaled_rhand_joints = self.q_sample(scaled_rhand_joints, t, noise=noise_scaled_rhand_joints)

            # pert_rhand_joints: bsz x ws x nnj x 3 #
            pert_rhand_joints = pert_normed_rhand_joints * self.std_jts.unsqueeze(0).to(rhand_joints.device) + self.avg_jts.unsqueeze(0).to(rhand_joints.device)

            ### Calculate moving-related energies ###
            # denormed_rhand_joints: bsz x nf x nnj x 3 -> denormed rhand joints here #
            ## denormed base pts to rhand joints -> denormed rel positions ##
            denormed_rel_base_pts_to_rhand_joints = rel_base_pts_to_rhand_joints * x_start['per_frame_std_joints_rel'] + x_start['per_frame_avg_joints_rel']
            denormed_rhand_joints = rhand_joints * extents_rhand_joints.unsqueeze(1) + avg_exp_rhand_joints.unsqueeze(1)
            denormed_dist_rel_base_pts_to_pert_rhand_joints = torch.sum(
                denormed_rel_base_pts_to_rhand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1
            )
            ## l2 distance to the real base pts ##
            k_f = 1.
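
            # --- Illustrative sketch (added for clarity; never invoked): the two joint-scaling
            # strategies selected by `jts_sclae_stra` above, condensed into one hypothetical
            # helper. "std" centers by the mean and scales per-axis by the std; "bbox" centers
            # by the bounding-box midpoint and scales by the box diagonal length. ---
            def _joint_scaling_sketch(exp_joints, strategy="std"):
                # exp_joints: bsz x (ws * nnj) x 3, all joints of all frames flattened together
                if strategy == "std":
                    center = torch.mean(exp_joints, dim=1, keepdim=True)   # bsz x 1 x 3
                    extent = torch.std(exp_joints, dim=1, keepdim=True)    # bsz x 1 x 3
                else:  # "bbox"
                    maxx, _ = torch.max(exp_joints, dim=1, keepdim=True)
                    minn, _ = torch.min(exp_joints, dim=1, keepdim=True)
                    center = (maxx + minn) / 2.                            # bsz x 1 x 3
                    extent = torch.sqrt(torch.sum((maxx - minn) ** 2, dim=-1, keepdim=True))  # bsz x 1 x 1
                return center, extent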
            ## l2 rel base pts to pert rhand joints ##
            # l2_rel_base_pts_to_pert_rhand_joints: bsz x nf x nnj x nnb #
            l2_rel_base_pts_to_pert_rhand_joints = torch.norm(denormed_rel_base_pts_to_rhand_joints, dim=-1)
            ### att_forces ###
            att_forces = torch.exp(-k_f * l2_rel_base_pts_to_pert_rhand_joints)  # bsz x nf x nnj x nnb #
            # bsz x (ws - 1) x nnj x nnb #
            att_forces = att_forces[:, :-1, :, :]  # attraction forces, dropping the last frame #
            # rhand_joints: ws x nnj x 3 -> (ws - 1) x nnj x 3 ## rhand_joints ##
            # bsz x (ws - 1) x nnj x 3 --> displacements #
            denormed_rhand_joints_disp = denormed_rhand_joints[:, 1:, :, :] - denormed_rhand_joints[:, :-1, :, :]
            # distance along base_normals; (ws - 1) x nnj x nnb x 3 --> bsz x (ws - 1) x nnj x nnb #
            # signed_dist_base_pts_to_pert_rhand_joints_along_normal: bsz x (ws - 1) x nnj x nnb #
            signed_dist_base_pts_to_rhand_joints_along_normal = torch.sum(
                base_normals.unsqueeze(1).unsqueeze(1) * denormed_rhand_joints_disp.unsqueeze(-2), dim=-1
            )
            # rel_base_pts_to_rhand_joints_vt_normal: bsz x (ws - 1) x nnj x nnb x 3 -> the relative positions perpendicular to the base normals #
            rel_base_pts_to_rhand_joints_vt_normal = denormed_rhand_joints_disp.unsqueeze(-2) - signed_dist_base_pts_to_rhand_joints_along_normal.unsqueeze(-1) * base_normals.unsqueeze(1).unsqueeze(1)
            dist_base_pts_to_rhand_joints_vt_normal = torch.sqrt(torch.sum(
                rel_base_pts_to_rhand_joints_vt_normal ** 2, dim=-1
            ))
            k_a = 1.
            k_b = 1.
            e_disp_rel_to_base_along_normals = k_a * att_forces * torch.abs(signed_dist_base_pts_to_rhand_joints_along_normal)
            # (ws - 1) x nnj x nnb # -> dist vt normals ##
            # note: the 'baes' typo in this name is kept; it must match the dict keys used elsewhere in the file #
            e_disp_rel_to_baes_vt_normals = k_b * att_forces * dist_base_pts_to_rhand_joints_vt_normal
            # nf x nnj x nnb ---> dist_vt_normals -> nf x nnj x nnb # # torch.sqrt() ##
            # dist_base_pts_to_pert_rhand_joints, rel_base_pts_to_pert_rhand_joints #
            # e_disp_rel_to_base_along_normals, e_disp_rel_to_baes_vt_normals #
            # print("per frame avg along normals,", x_start['per_frame_avg_disp_along_normals'].size(), "per frame std along normals", x_start['per_frame_std_disp_along_normals'].size(), "e_disp_rel_to_base_along_normals", e_disp_rel_to_base_along_normals.size(), "e_disp_rel_to_baes_vt_normals", e_disp_rel_to_baes_vt_normals.size())
            # per_frame_avg_disp_along_normals, per_frame_std_disp_along_normals
            # per_frame_avg_disp_vt_normals, per_frame_std_disp_vt_normals
            e_disp_rel_to_base_along_normals = (e_disp_rel_to_base_along_normals - x_start['per_frame_avg_disp_along_normals']) / x_start['per_frame_std_disp_along_normals']
            e_disp_rel_to_baes_vt_normals = (e_disp_rel_to_baes_vt_normals - x_start['per_frame_avg_disp_vt_normals']) / x_start['per_frame_std_disp_vt_normals']
        elif self.denoising_stra == "motion_to_rep":
            # print(f"Using denoising stra: {self.denoising_stra}")
            joints_noise = torch.randn_like(rhand_joints)
            pert_rhand_joints = self.q_sample(rhand_joints, t, noise=joints_noise)  # q_sample for the noisy joints #
            # pert_rhand_joints: bsz x nf x nnj x 3 ## --> pert joints #
            # base_pts: bsz x nnb x 3 # avg jts and std jts ##
            pert_rhand_joints_denorm = pert_rhand_joints * self.std_jts.unsqueeze(0).to(rhand_joints.device) + self.avg_jts.unsqueeze(0).to(rhand_joints.device)
            # bsz x nf x nnj x nnb x 3 #
            rel_base_pts_to_pert_rhand_joints = pert_rhand_joints_denorm.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1)
            dist_base_pts_to_pert_rhand_joints = torch.sum(
                rel_base_pts_to_pert_rhand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1
            )
        else:
            raise ValueError(f"Unrecognized denoising_stra: {self.denoising_stra}")
        ''' GET rel and dists '''
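
        # --- Illustrative sketch (added for clarity; never invoked): the contact "energy" terms
        # computed in the "rep" branch above, condensed into one hypothetical helper. `disp` is
        # the per-frame joint displacement broadcast against the base points, `normals` the base
        # normals, and `dists` the joint-to-base-point l2 distances driving the attraction
        # weights. ---
        def _contact_energy_sketch(disp, normals, dists, k_f=1., k_a=1., k_b=1.):
            # attraction weight decays exponentially with distance to the base point
            att = torch.exp(-k_f * dists)                     # ... x nnj x nnb
            # signed displacement component along the base normal
            along = torch.sum(normals * disp, dim=-1)         # ... x nnj x nnb
            # residual displacement perpendicular to the normal
            vt = disp - along.unsqueeze(-1) * normals
            vt_norm = torch.sqrt(torch.sum(vt ** 2, dim=-1))
            e_along = k_a * att * torch.abs(along)
            e_vt = k_b * att * vt_norm
            return e_along, e_vt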
        input_data = {
            'base_pts': base_pts.clone(),  # base pts ###
            'base_normals': base_normals.clone(),  # base normals ###
            ## rel_base_pts_to_rhand_joints: bsz x ws x nnj x nnb x 3 ## -> 1) encode to the latent space; 2) add noise in the latent space; 3) denoise latent codes; 4) use denoised latent codes for further prediction ###
            # 'rel_base_pts_to_rhand_joints': rel_base_pts_to_pert_rhand_joints.clone(),
            'rel_base_pts_to_rhand_joints': rel_base_pts_to_rhand_joints.clone(),
            # 'dist_base_pts_to_rhand_joints': dist_base_pts_to_pert_rhand_joints.clone(),
            # 'pert_rhand_joints': pert_normed_rhand_joints,  # scaled_rhand_joints, pert_scaled_rhand_joints
            'pert_rhand_joints': pert_scaled_rhand_joints,
            # 'rhand_joints': rhand_joints,
            'avg_joints_sequence': avg_joints_sequence,  ## bsz x nnjoints x 3 here for the avg_joints ##
        }
        if 'sampled_base_pts_nearest_obj_pc' in x_start:
            input_data.update(ambient_xstart_dict)
        # dist_base_pts_to_pert_rhand_joints, rel_base_pts_to_pert_rhand_joints #
        # e_disp_rel_to_base_along_normals, e_disp_rel_to_baes_vt_normals #
        # disp_dist: bsz x (ws - 1) x nnj x nnb # slice the time dimension (the original [:-1] sliced the batch dimension) #
        disp_dist = dist_base_pts_to_rhand_joints[:, :-1]
        input_data.update(
            {
                # e_disp_rel_to_base_along_normals, e_disp_rel_to_baes_vt_normals ###
                'e_disp_rel_to_base_along_normals': e_disp_rel_to_base_along_normals,
                'e_disp_rel_to_baes_vt_normals': e_disp_rel_to_baes_vt_normals,
                ## clean values # the denoising space is then transformed to the latent space; noisy inputs -> latent code --> can we really denoise them correctly?
                "obj_pts_disp": obj_pts_disp,
                'vel_obj_pts_to_hand_pts': vel_obj_pts_to_hand_pts,
                'disp_dist': disp_dist
            }
        )
        # input_data.update(
        #     {k: x_start[k].clone() for k in x_start if k not in input_data}
        # )
        # gaussian diffusion ours ## # rel_base_pts_to_rhand_joints in the input_data #
        if model_kwargs is None:
            model_kwargs = {}
        terms = {}
        # latents in the latent space # # sequence latents #
        if self.args.train_diff:
            with torch.no_grad():
                out_dict = model(input_data, self._scale_timesteps(t).clone())
        else:
            # clean_joint_seq_latents: seq x bs x d #
            # clean_joint_seq_latents = model(input_data, self._scale_timesteps(t).clone())
            ### the strategy of removing noise from the corresponding quantities ###
            out_dict = model(input_data, self._scale_timesteps(t).clone())
        ### get model output dictionary ###
        KL_loss = 0.
        # out dict of the model # # resume checkpoints # dec_in_dict
        dec_in_dict = {}
        if self.diff_basejtse:
            ### Sample for perturbed basejtse seq latents ###
            basejtse_seq_latents = out_dict["base_jts_e_feats"]
            noise_basejtse_seq_latents = th.randn_like(basejtse_seq_latents)
            pert_basejtse_seq_latents = self.q_sample(basejtse_seq_latents.permute(1, 0, 2).contiguous(), t, noise=noise_basejtse_seq_latents.permute(1, 0, 2).contiguous()).permute(1, 0, 2).contiguous()
            ### Sample for perturbed basejtse seq latents ###
            dec_in_dict['base_jts_e_feats'] = pert_basejtse_seq_latents
            dec_in_dict['base_jts_e_feats_enc'] = basejtse_seq_latents
            if self.args.kl_weights > 0.
and "base_jts_e_feats_mean" in out_dict and not self.args.train_diff: # clean_joint_seq_latents: seq_len x bs x d # log_p_base_jts_e_seq = model_util.standard_normal_logprob(basejtse_seq_latents) log_p_base_jts_e_seq = log_p_base_jts_e_seq.permute(1, 0, 2).contiguous() # log_p_base_jts_e_seq = log_p_base_jts_e_seq.sum(dim=-1).mean(dim=-1) # log_p_joints_seq entropy_base_jts_e_seq = model_util.gaussian_entropy(out_dict['base_jts_e_feats_logvar'].permute(1, 2, 0)).mean(dim=-1) loss_prior_base_jts_e_seq = (- log_p_base_jts_e_seq - entropy_base_jts_e_seq) KL_loss += loss_prior_base_jts_e_seq # dec_in_dict = { # "joints_seq_latents": pert_joint_seq_latents, # "rel_base_pts_outputs": pert_basejtsrel_seq_latents, # "base_jts_e_feats": pert_basejtse_seq_latents, # } # # dec_clean_joint_seq: bsz x ws x nnj x 3 # # dec_clena_seq_latents: seq x bs x d # dec_clean_joint_seq, dec_clena_seq_latents = model.model.dec_latents_to_joints_with_t(pert_joint_seq_latents, self._scale_timesteps(t).clone()) # dec_clean_joint_seq: bsz x ws x nnj x 3 # dec_clena_seq_latents: seq x bs x d dec_out_dict = model.model.dec_latents_to_joints_with_t(dec_in_dict, input_data, self._scale_timesteps(t).clone()) terms['rot_mse'] = 0. if self.diff_basejtse: dec_base_jts_e_feats = dec_out_dict['base_jts_e_feats'] dec_e_along_normals = dec_out_dict['dec_e_along_normals'] dec_e_vt_normals = dec_out_dict['dec_e_vt_normals'] dec_d = dec_out_dict['dec_d'] rel_vel_dec = dec_out_dict['rel_vel_dec'] # e_disp_rel_to_base_along_normals, e_disp_rel_to_baes_vt_normals # rel_base_pts_to_rhand_joints: bsz x ws x nnj x nnb x 3 # basejtse_along_normals_pred_loss = torch.sum( (e_disp_rel_to_base_along_normals.unsqueeze(-1) - dec_e_along_normals.unsqueeze(-1)) ** 2, dim=-1 ).mean(dim=-1).mean(dim=-1).mean(dim=-1) basejtse_vt_normals_pred_loss = torch.sum( (e_disp_rel_to_baes_vt_normals.unsqueeze(-1) - dec_e_vt_normals.unsqueeze(-1)) ** 2, dim=-1 ).mean(dim=-1).mean(dim=-1).mean(dim=-1) d_pred_loss = torch.sum( (dec_d.unsqueeze(-1) - disp_dist.unsqueeze(-1)) ** 2, dim=-1 ).mean(dim=-1).mean(dim=-1).mean(dim=-1) rel_vel_pred_loss = torch.sum( (rel_vel_dec.unsqueeze(-1) - vel_obj_pts_to_hand_pts.unsqueeze(-1)) ** 2, dim=-1 ).mean(dim=-1).mean(dim=-1).mean(dim=-1) # basejtse_latent_denoising_loss = (torch.sum( # (basejtse_seq_latents.permute(1, 0, 2).contiguous() - dec_base_jts_e_feats.permute(1, 0, 2).contiguous()) ** 2, dim=-1 # ) / basejtse_seq_latents.size(-1)).mean(dim=-1) if self.args.pred_diff_noise: # noise_joint_seq_latents basejtse_latent_denoising_loss = (torch.sum( (basejtse_seq_latents.permute(1, 0, 2).contiguous() - noise_basejtse_seq_latents.permute(1, 0, 2).contiguous()) ** 2, dim=-1 ) / basejtse_seq_latents.size(-1)).mean(dim=-1) else: basejtse_latent_denoising_loss = (torch.sum( (basejtse_seq_latents.permute(1, 0, 2).contiguous() - dec_base_jts_e_feats.permute(1, 0, 2).contiguous()) ** 2, dim=-1 ) / basejtse_seq_latents.size(-1)).mean(dim=-1) # find out kong # if self.args.train_enc: basejtse_latent_denoising_loss = torch.zeros_like(basejtse_latent_denoising_loss) if self.args.train_diff: # train_diff # no basejtse denoising losses ## # basejtse_latent_denoising_loss = torch.zeros_like(basejtse_latent_denoising_loss) basejtse_along_normals_pred_loss = torch.zeros_like(basejtse_along_normals_pred_loss) basejtse_vt_normals_pred_loss = torch.zeros_like(basejtse_vt_normals_pred_loss) d_pred_loss = torch.zeros_like(d_pred_loss) rel_vel_pred_loss = torch.zeros_like(rel_vel_pred_loss) terms['basejtse_along_normals_pred_loss'] = 
basejtse_along_normals_pred_loss
            terms['basejtse_vt_normals_pred_loss'] = basejtse_vt_normals_pred_loss
            terms['basejtse_latent_denoising_loss'] = basejtse_latent_denoising_loss
            terms['d_pred_loss'] = d_pred_loss
            terms['rel_vel_pred_loss'] = rel_vel_pred_loss
            # terms['rot_mse'] += basejtse_along_normals_pred_loss * 20 + basejtse_vt_normals_pred_loss * 20 + basejtse_latent_denoising_loss
            # note: the d / rel-vel terms originally read `self.basejtse_along_normal_loss_coeff`, which is not an attribute of this class; `self.args.` is the intended access #
            terms['rot_mse'] += basejtse_along_normals_pred_loss * self.args.basejtse_along_normal_loss_coeff + basejtse_vt_normals_pred_loss * self.args.basejtse_vt_normal_loss_coeff + basejtse_latent_denoising_loss + d_pred_loss * self.args.basejtse_along_normal_loss_coeff + rel_vel_pred_loss * self.args.basejtse_along_normal_loss_coeff
        if self.args.kl_weights > 0. and not self.args.train_diff:
            terms['KL_loss'] = KL_loss
            terms['rot_mse'] += KL_loss * self.args.kl_weights
        # sv_inter_dict = {
        #     'dec_joints': dec_clean_joint_seq.detach().cpu().numpy(),
        #     'rhand_joints': rhand_joints.detach().cpu().numpy(),
        #     't': t.detach().cpu().numpy(),
        # }
        # sv_inter_dict_fn = os.path.join(args.save_dir, )
        ### construct the final loss ###
        # terms['rot_mse'] = jts_pred_loss * 20 + jts_latent_denoising_loss + basejtsrel_pred_loss * 20 + basejtsrel_latent_denoising_loss + basejtse_along_normals_pred_loss * 20 + basejtse_vt_normals_pred_loss * 20 + basejtse_latent_denoising_loss
        ### === only use joints-only losses === ###
        # terms['rot_mse'] = jts_pred_loss * 20 + jts_latent_denoising_loss
        ### === only use base-jts-rel losses === ###
        # terms['rot_mse'] = basejtsrel_pred_loss * 20 + basejtsrel_latent_denoising_loss
        ### === only use base-jts-e losses === ###
        # terms['rot_mse'] = basejtse_along_normals_pred_loss * 20 + basejtse_vt_normals_pred_loss * 20 + basejtse_latent_denoising_loss
        # terms["rot_mse"] = self.masked_l2(target, model_output, mask)  # mean_flat(rot_mse)
        # sv_out_in = {
        #     'model_output': model_output.detach().cpu().numpy(),
        #     'target': target.detach().cpu().numpy(),
        #     't': t.detach().cpu().numpy(),
        # }
        import os
        import datetime
        cur_time_stamp = datetime.datetime.now().timestamp()
        cur_time_stamp = str(cur_time_stamp)
        # sv_dir_rt = "/data1/sim/motion-diffusion-model"
        # sv_out_fn = os.path.join(sv_dir_rt, f"out_{cur_time_stamp}.npy")
        # np.save(sv_out_fn, sv_inter_dict)
        # print(f"Samples saved to {sv_out_fn}")
        target_xyz, model_output_xyz = None, None
        terms["loss"] = terms["rot_mse"] + terms.get('vb', 0.) +\
            (self.lambda_vel * terms.get('vel_mse', 0.)) +\
            (self.lambda_rcxyz * terms.get('rcxyz_mse', 0.)) + \
            (self.lambda_fc * terms.get('fc', 0.))
        # else:
        #     raise NotImplementedError(self.loss_type)
        return terms

    ## training losses ##
    def predict_sample_single_step(self, model, x_start, t, model_kwargs=None, noise=None, dataset=None):
        """
        Predict a single denoising step and compute training losses for a single timestep.

        :param model: the model to evaluate loss on.
        :param x_start: the [N x C x ...] tensor of inputs.
        :param t: a batch of timestep indices.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param noise: if specified, the specific Gaussian noise to try to remove.
        :return: a dict with the key "loss" containing a tensor of shape [N].
                 Some mean or variance settings may also have other keys.
""" # enc = model.model._modules['module'] # enc = model.model mask = model_kwargs['y']['mask'] ## rot2xyz; # # ### avg_joints, std_joints ### # if 'avg_joints' in model_kwargs['y']: avg_joints = model_kwargs['y']['avg_joints'].unsqueeze(-1) std_joints = model_kwargs['y']['std_joints'].unsqueeze(-1).unsqueeze(-1) else: avg_joints = None std_joints = None # ### avg_joints, std_joints ### # # get_xyz = lambda sample: enc.rot2xyz(sample, mask=None, pose_rep=enc.pose_rep, translation=enc.translation, # glob=enc.glob, ## rot2xyz; ## # # jointstype='vertices', # 3.4 iter/sec # USED ALSO IN MotionCLIP # jointstype='smpl', # 3.4 iter/sec # vertstrans=False) if model_kwargs is None: ## predict single steps --> model_kwargs = {} if noise is None: noise = th.randn_like(x_start) x_t = self.q_sample(x_start, t, noise=noise) # randn_like for x_start, t, x_t --> get x_t from x_start # # how we control the tiem stamp t? terms = {} if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: terms["loss"] = self._vb_terms_bpd( ## vb terms bpd # model=model, x_start=x_start, x_t=x_t, t=t, clip_denoised=False, model_kwargs=model_kwargs, )["output"] if self.loss_type == LossType.RESCALED_KL: terms["loss"] *= self.num_timesteps elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: model_output = model(x_t, self._scale_timesteps(t), **model_kwargs) # model_output ---> model x_t # if self.model_var_type in [ ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE, ]: # s B, C = x_t.shape[:2] assert model_output.shape == (B, C * 2, *x_t.shape[2:]) model_output, model_var_values = th.split(model_output, C, dim=1) # Learn the variance using the variational bound, but don't let # it affect our mean prediction. frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) terms["vb"] = self._vb_terms_bpd( model=lambda *args, r=frozen_out: r, x_start=x_start, x_t=x_t, t=t, clip_denoised=False, )["output"] if self.loss_type == LossType.RESCALED_MSE: # Divide by 1000 for equivalence with initial implementation. # Without a factor of 1/1000, the VB term hurts the MSE term. terms["vb"] *= self.num_timesteps / 1000.0 target = { # q posterior mean variance # ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( x_start=x_start, x_t=x_t, t=t )[0], ModelMeanType.START_X: x_start, ModelMeanType.EPSILON: noise, }[self.model_mean_type] # model mean type --> mean type # assert model_output.shape == target.shape == x_start.shape # [bs, njoints, nfeats, nframes] terms["rot_mse"] = self.masked_l2(target, model_output, mask) # mean_flat(rot_mse) # model_output, target, t # ### avg_joints, std_joints ### # if avg_joints is not None: print(f"model_output: {model_output.size()}, target: {target.size()}, std_joints: {std_joints.size()}, avg_joints: {avg_joints.size()}") print(f"Denormalizing joints...") model_output = (model_output * std_joints) + avg_joints target = (target * std_joints) + avg_joints sv_out_in = { # 'model_output': model_output.detach().cpu().numpy(), 'target': target.detach().cpu().numpy(), 't': t.detach().cpu().numpy(), } import os import datetime cur_time_stamp = datetime.datetime.now().timestamp() cur_time_stamp = str(cur_time_stamp) sv_dir_rt = "/data1/sim/motion-diffusion-model" sv_out_fn = os.path.join(sv_dir_rt, f"out_{cur_time_stamp}.npy") np.save(sv_out_fn,sv_out_in ) print(f"Samples saved to {sv_out_fn}") target_xyz, model_output_xyz = None, None self.lambda_rcxyz = 0. 
if self.lambda_rcxyz > 0.: target_xyz = get_xyz(target) # [bs, nvertices(vertices)/njoints(smpl), 3, nframes] model_output_xyz = get_xyz(model_output) # [bs, nvertices, 3, nframes] terms["rcxyz_mse"] = self.masked_l2(target_xyz, model_output_xyz, mask) # mean_flat((target_xyz - model_output_xyz) ** 2) if self.lambda_vel_rcxyz > 0.: if self.data_rep == 'rot6d' and dataset.dataname in ['humanact12', 'uestc']: target_xyz = get_xyz(target) if target_xyz is None else target_xyz model_output_xyz = get_xyz(model_output) if model_output_xyz is None else model_output_xyz target_xyz_vel = (target_xyz[:, :, :, 1:] - target_xyz[:, :, :, :-1]) model_output_xyz_vel = (model_output_xyz[:, :, :, 1:] - model_output_xyz[:, :, :, :-1]) terms["vel_xyz_mse"] = self.masked_l2(target_xyz_vel, model_output_xyz_vel, mask[:, :, :, 1:]) if self.lambda_fc > 0.: ## lambda fc ## torch.autograd.set_detect_anomaly(True) if self.data_rep == 'rot6d' and dataset.dataname in ['humanact12', 'uestc']: target_xyz = get_xyz(target) if target_xyz is None else target_xyz model_output_xyz = get_xyz(model_output) if model_output_xyz is None else model_output_xyz # 'L_Ankle', # 7, 'R_Ankle', # 8 , 'L_Foot', # 10, 'R_Foot', # 11 l_ankle_idx, r_ankle_idx, l_foot_idx, r_foot_idx = 7, 8, 10, 11 relevant_joints = [l_ankle_idx, l_foot_idx, r_ankle_idx, r_foot_idx] gt_joint_xyz = target_xyz[:, relevant_joints, :, :] # [BatchSize, 4, 3, Frames] gt_joint_vel = torch.linalg.norm(gt_joint_xyz[:, :, :, 1:] - gt_joint_xyz[:, :, :, :-1], axis=2) # [BatchSize, 4, Frames] fc_mask = torch.unsqueeze((gt_joint_vel <= 0.01), dim=2).repeat(1, 1, 3, 1) pred_joint_xyz = model_output_xyz[:, relevant_joints, :, :] # [BatchSize, 4, 3, Frames] pred_vel = pred_joint_xyz[:, :, :, 1:] - pred_joint_xyz[:, :, :, :-1] pred_vel[~fc_mask] = 0 terms["fc"] = self.masked_l2(pred_vel, torch.zeros(pred_vel.shape, device=pred_vel.device), mask[:, :, :, 1:]) if self.lambda_vel > 0.: target_vel = (target[..., 1:] - target[..., :-1]) model_output_vel = (model_output[..., 1:] - model_output[..., :-1]) terms["vel_mse"] = self.masked_l2(target_vel[:, :-1, :, :], # Remove last joint, is the root location! model_output_vel[:, :-1, :, :], mask[:, :, :, 1:]) # mean_flat((target_vel - model_output_vel) ** 2) terms["loss"] = terms["rot_mse"] + terms.get('vb', 0.) 
+\ (self.lambda_vel * terms.get('vel_mse', 0.)) +\ (self.lambda_rcxyz * terms.get('rcxyz_mse', 0.)) + \ (self.lambda_fc * terms.get('fc', 0.)) else: raise NotImplementedError(self.loss_type) return terms, model_output, target, t def fc_loss_rot_repr(self, gt_xyz, pred_xyz, mask): def to_np_cpu(x): return x.detach().cpu().numpy() """ pose_xyz: SMPL batch tensor of shape: [BatchSize, 24, 3, Frames] """ # 'L_Ankle', # 7, 'R_Ankle', # 8 , 'L_Foot', # 10, 'R_Foot', # 11 l_ankle_idx, r_ankle_idx = 7, 8 l_foot_idx, r_foot_idx = 10, 11 """ Contact calculated by 'Kfir Method' Commented code)""" # contact_signal = torch.zeros((pose_xyz.shape[0], pose_xyz.shape[3], 2), device=pose_xyz.device) # [BatchSize, Frames, 2] # left_xyz = 0.5 * (pose_xyz[:, l_ankle_idx, :, :] + pose_xyz[:, l_foot_idx, :, :]) # [BatchSize, 3, Frames] # right_xyz = 0.5 * (pose_xyz[:, r_ankle_idx, :, :] + pose_xyz[:, r_foot_idx, :, :]) # left_z, right_z = left_xyz[:, 2, :], right_xyz[:, 2, :] # [BatchSize, Frames] # left_velocity = torch.linalg.norm(left_xyz[:, :, 2:] - left_xyz[:, :, :-2], axis=1) # [BatchSize, Frames] # right_velocity = torch.linalg.norm(left_xyz[:, :, 2:] - left_xyz[:, :, :-2], axis=1) # # left_z_mask = left_z <= torch.mean(torch.sort(left_z)[0][:, :left_z.shape[1] // 5], axis=-1) # left_z_mask = torch.stack([left_z_mask, left_z_mask], dim=-1) # [BatchSize, Frames, 2] # left_z_mask[:, :, 1] = False # Blank right side # contact_signal[left_z_mask] = 0.4 # # right_z_mask = right_z <= torch.mean(torch.sort(right_z)[0][:, :right_z.shape[1] // 5], axis=-1) # right_z_mask = torch.stack([right_z_mask, right_z_mask], dim=-1) # [BatchSize, Frames, 2] # right_z_mask[:, :, 0] = False # Blank left side # contact_signal[right_z_mask] = 0.4 # contact_signal[left_z <= (torch.mean(torch.sort(left_z)[:left_z.shape[0] // 5]) + 20), 0] = 1 # contact_signal[right_z <= (torch.mean(torch.sort(right_z)[:right_z.shape[0] // 5]) + 20), 1] = 1 # plt.plot(to_np_cpu(left_z[0]), label='left_z') # plt.plot(to_np_cpu(left_velocity[0]), label='left_velocity') # plt.plot(to_np_cpu(contact_signal[0, :, 0]), label='left_fc') # plt.grid() # plt.legend() # plt.show() # plt.plot(to_np_cpu(right_z[0]), label='right_z') # plt.plot(to_np_cpu(right_velocity[0]), label='right_velocity') # plt.plot(to_np_cpu(contact_signal[0, :, 1]), label='right_fc') # plt.grid() # plt.legend() # plt.show() gt_joint_xyz = gt_xyz[:, [l_ankle_idx, l_foot_idx, r_ankle_idx, r_foot_idx], :, :] # [BatchSize, 4, 3, Frames] gt_joint_vel = torch.linalg.norm(gt_joint_xyz[:, :, :, 1:] - gt_joint_xyz[:, :, :, :-1], axis=2) # [BatchSize, 4, Frames] fc_mask = (gt_joint_vel <= 0.01) pred_joint_xyz = pred_xyz[:, [l_ankle_idx, l_foot_idx, r_ankle_idx, r_foot_idx], :, :] # [BatchSize, 4, 3, Frames] pred_joint_vel = torch.linalg.norm(pred_joint_xyz[:, :, :, 1:] - pred_joint_xyz[:, :, :, :-1], axis=2) # [BatchSize, 4, Frames] pred_joint_vel[~fc_mask] = 0 # Blank non-contact velocities frames. [BS,4,FRAMES] pred_joint_vel = torch.unsqueeze(pred_joint_vel, dim=2) """DEBUG CODE""" # print(f'mask: {mask.shape}') # print(f'pred_joint_vel: {pred_joint_vel.shape}') # plt.title(f'Joint: {joint_idx}') # plt.plot(to_np_cpu(gt_joint_vel[0]), label='velocity') # plt.plot(to_np_cpu(fc_mask[0]), label='fc') # plt.grid() # plt.legend() # plt.show() return self.masked_l2(pred_joint_vel, torch.zeros(pred_joint_vel.shape, device=pred_joint_vel.device), mask[:, :, :, 1:]) # TODO - NOT USED YET, JUST COMMITING TO NOT DELETE THIS AND KEEP INITIAL IMPLEMENTATION, NOT DONE! 
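
    # (The TODO above refers to foot_contact_loss_humanml3d below.)
    # --- Illustrative usage sketch (added for clarity; a hypothetical standalone helper, not
    # part of the class API): the foot-contact idea used by fc_loss_rot_repr above -- zero out
    # the predicted foot velocity wherever the ground-truth foot is (nearly) static, so only
    # contact-frame sliding is penalized. ---
    @staticmethod
    def _fc_velocity_residual_sketch(gt_xyz, pred_xyz, eps=0.01):
        # gt_xyz / pred_xyz: [bsz, njoints, 3, nframes]
        gt_vel = torch.linalg.norm(gt_xyz[..., 1:] - gt_xyz[..., :-1], dim=2)  # [bsz, nj, nf - 1]
        fc_mask = (gt_vel <= eps).unsqueeze(2)                                 # contact frames, broadcast over xyz
        pred_vel = pred_xyz[..., 1:] - pred_xyz[..., :-1]
        # the residual below is what an l2 loss would push toward zero on contact frames
        return pred_vel * fc_mask.float()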
    def foot_contact_loss_humanml3d(self, target, model_output):
        # root_rot_velocity (B, seq_len, 1)
        # root_linear_velocity (B, seq_len, 2)
        # root_y (B, seq_len, 1)
        # ric_data (B, seq_len, (joint_num - 1)*3) , XYZ
        # rot_data (B, seq_len, (joint_num - 1)*6) , 6D
        # local_velocity (B, seq_len, joint_num*3) , XYZ
        # foot contact (B, seq_len, 4) ,
        target_fc = target[:, -4:, :, :]
        root_rot_velocity = target[:, :1, :, :]
        root_linear_velocity = target[:, 1:3, :, :]
        root_y = target[:, 3:4, :, :]
        ric_data = target[:, 4:67, :, :]  # 4+(3*21)=67
        rot_data = target[:, 67:193, :, :]  # 67+(6*21)=193
        local_velocity = target[:, 193:259, :, :]  # 193+(3*22)=259
        contact = target[:, 259:, :, :]  # 259:263
        contact_mask_gt = contact > 0.5  # contact mask order for indexes are fid_l [7, 10], fid_r [8, 11]
        vel_lf_7 = local_velocity[:, 7 * 3:8 * 3, :, :]
        vel_rf_8 = local_velocity[:, 8 * 3:9 * 3, :, :]
        vel_lf_10 = local_velocity[:, 10 * 3:11 * 3, :, :]
        vel_rf_11 = local_velocity[:, 11 * 3:12 * 3, :, :]
        calc_vel_lf_7 = ric_data[:, 6 * 3:7 * 3, :, 1:] - ric_data[:, 6 * 3:7 * 3, :, :-1]
        calc_vel_rf_8 = ric_data[:, 7 * 3:8 * 3, :, 1:] - ric_data[:, 7 * 3:8 * 3, :, :-1]
        calc_vel_lf_10 = ric_data[:, 9 * 3:10 * 3, :, 1:] - ric_data[:, 9 * 3:10 * 3, :, :-1]
        calc_vel_rf_11 = ric_data[:, 10 * 3:11 * 3, :, 1:] - ric_data[:, 10 * 3:11 * 3, :, :-1]
        # vel_foots = torch.stack([vel_lf_7, vel_lf_10, vel_rf_8, vel_rf_11], dim=1)
        # note: the calc_vel list below is reordered (lf_7, lf_10, rf_8, rf_11) so that each
        # calculated velocity is zipped with the matching vector velocity, joint index, and
        # contact-mask index; the original order (lf_7, rf_8, lf_10, rf_11) mispaired them #
        for chosen_vel_foot_calc, chosen_vel_foot, joint_idx, contact_mask_idx in zip(
                [calc_vel_lf_7, calc_vel_lf_10, calc_vel_rf_8, calc_vel_rf_11],
                [vel_lf_7, vel_lf_10, vel_rf_8, vel_rf_11],
                [7, 10, 8, 11],
                [0, 1, 2, 3]):
            tmp_mask_gt = contact_mask_gt[:, contact_mask_idx, :, :].cpu().detach().numpy().reshape(-1).astype(int)
            chosen_vel_norm = np.linalg.norm(chosen_vel_foot.cpu().detach().numpy().reshape((3, -1)), axis=0)
            chosen_vel_calc_norm = np.linalg.norm(chosen_vel_foot_calc.cpu().detach().numpy().reshape((3, -1)), axis=0)
            print(tmp_mask_gt.shape)
            print(chosen_vel_foot.shape)
            print(chosen_vel_calc_norm.shape)
            import matplotlib.pyplot as plt
            plt.plot(tmp_mask_gt, label='FC mask')
            plt.plot(chosen_vel_norm, label='Vel. XYZ norm (from vector)')
            plt.plot(chosen_vel_calc_norm, label='Vel. XYZ norm (calculated diff XYZ)')
            plt.title(f'FC idx {contact_mask_idx}, Joint Index {joint_idx}')
            plt.legend()
            plt.show()
        # print(vel_foots.shape)
        return 0
    # TODO - NOT USED YET, JUST COMMITTING TO NOT DELETE THIS AND KEEP INITIAL IMPLEMENTATION, NOT DONE!
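
    # (The TODO above refers to velocity_consistency_loss_humanml3d below.)
    # --- Illustrative sketch (added for clarity; a hypothetical helper, not part of the class):
    # the HumanML3D 263-dim feature layout documented in the two humanml3d losses, split into
    # named slices for a 22-joint skeleton. ---
    @staticmethod
    def _split_humanml3d_sketch(x):
        # x: [bsz, 263, 1, seq_len], channel layout as annotated above
        return {
            'root_rot_velocity': x[:, 0:1],      # 1
            'root_linear_velocity': x[:, 1:3],   # 2
            'root_y': x[:, 3:4],                 # 1
            'ric_data': x[:, 4:67],              # (22 - 1) * 3 = 63
            'rot_data': x[:, 67:193],            # (22 - 1) * 6 = 126
            'local_velocity': x[:, 193:259],     # 22 * 3 = 66
            'foot_contact': x[:, 259:263],       # 4
        }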
def velocity_consistency_loss_humanml3d(self, target, model_output): # root_rot_velocity (B, seq_len, 1) # root_linear_velocity (B, seq_len, 2) # root_y (B, seq_len, 1) # ric_data (B, seq_len, (joint_num - 1)*3) , XYZ # rot_data (B, seq_len, (joint_num - 1)*6) , 6D # local_velocity (B, seq_len, joint_num*3) , XYZ # foot contact (B, seq_len, 4) , target_fc = target[:, -4:, :, :] root_rot_velocity = target[:, :1, :, :] root_linear_velocity = target[:, 1:3, :, :] root_y = target[:, 3:4, :, :] ric_data = target[:, 4:67, :, :] # 4+(3*21)=67 rot_data = target[:, 67:193, :, :] # 67+(6*21)=193 local_velocity = target[:, 193:259, :, :] # 193+(3*22)=259 contact = target[:, 259:, :, :] # 193+(3*22)=259 calc_vel_from_xyz = ric_data[:, :, :, 1:] - ric_data[:, :, :, :-1] velocity_from_vector = local_velocity[:, 3:, :, 1:] # Slicing out root r_rot_quat, r_pos = motion_process.recover_root_rot_pos(target.permute(0, 2, 3, 1).type(th.FloatTensor)) print(f'r_rot_quat: {r_rot_quat.shape}') print(f'calc_vel_from_xyz: {calc_vel_from_xyz.shape}') calc_vel_from_xyz = calc_vel_from_xyz.permute(0, 2, 3, 1) calc_vel_from_xyz = calc_vel_from_xyz.reshape((1, 1, -1, 21, 3)).type(th.FloatTensor) r_rot_quat_adapted = r_rot_quat[..., :-1, None, :].repeat((1,1,1,21,1)).to(calc_vel_from_xyz.device) print(f'calc_vel_from_xyz: {calc_vel_from_xyz.shape} , {calc_vel_from_xyz.device}') print(f'r_rot_quat_adapted: {r_rot_quat_adapted.shape}, {r_rot_quat_adapted.device}') calc_vel_from_xyz = motion_process.qrot(r_rot_quat_adapted, calc_vel_from_xyz) calc_vel_from_xyz = calc_vel_from_xyz.reshape((1, 1, -1, 21 * 3)) calc_vel_from_xyz = calc_vel_from_xyz.permute(0, 3, 1, 2) print(f'calc_vel_from_xyz: {calc_vel_from_xyz.shape} , {calc_vel_from_xyz.device}') import matplotlib.pyplot as plt for i in range(21): plt.plot(np.linalg.norm(calc_vel_from_xyz[:,i*3:(i+1)*3,:,:].cpu().detach().numpy().reshape((3, -1)), axis=0), label='Calc Vel') plt.plot(np.linalg.norm(velocity_from_vector[:,i*3:(i+1)*3,:,:].cpu().detach().numpy().reshape((3, -1)), axis=0), label='Vector Vel') plt.title(f'Joint idx: {i}') plt.legend() plt.show() print(calc_vel_from_xyz.shape) print(velocity_from_vector.shape) diff = calc_vel_from_xyz-velocity_from_vector print(np.linalg.norm(diff.cpu().detach().numpy().reshape((63, -1)), axis=0)) return 0 def _prior_bpd(self, x_start): """ Get the prior KL term for the variational lower-bound, measured in bits-per-dim. This term can't be optimized, as it only depends on the encoder. :param x_start: the [N x C x ...] tensor of inputs. :return: a batch of [N] KL values (in bits), one per batch element. """ batch_size = x_start.shape[0] t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) kl_prior = normal_kl( mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 ) return mean_flat(kl_prior) / np.log(2.0) def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): """ Compute the entire variational lower-bound, measured in bits-per-dim, as well as other related quantities. :param model: the model to evaluate loss on. :param x_start: the [N x C x ...] tensor of inputs. :param clip_denoised: if True, clip denoised samples. :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :return: a dict containing the following keys: - total_bpd: the total variational lower-bound, per batch element. - prior_bpd: the prior term in the lower-bound. 
- vb: an [N x T] tensor of terms in the lower-bound. - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. - mse: an [N x T] tensor of epsilon MSEs for each timestep. """ device = x_start.device batch_size = x_start.shape[0] vb = [] xstart_mse = [] mse = [] for t in list(range(self.num_timesteps))[::-1]: t_batch = th.tensor([t] * batch_size, device=device) noise = th.randn_like(x_start) x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) # Calculate VLB term at the current timestep with th.no_grad(): out = self._vb_terms_bpd( model, x_start=x_start, x_t=x_t, t=t_batch, clip_denoised=clip_denoised, model_kwargs=model_kwargs, ) vb.append(out["output"]) xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) mse.append(mean_flat((eps - noise) ** 2)) vb = th.stack(vb, dim=1) xstart_mse = th.stack(xstart_mse, dim=1) mse = th.stack(mse, dim=1) prior_bpd = self._prior_bpd(x_start) total_bpd = vb.sum(dim=1) + prior_bpd return { "total_bpd": total_bpd, "prior_bpd": prior_bpd, "vb": vb, "xstart_mse": xstart_mse, "mse": mse, } class VarianceSchedule(torch.nn.Module): def __init__(self, num_steps, betas): super().__init__() # assert mode in ('linear', ) self.num_steps = num_steps ## variance schedule ## # self.beta_1 = beta_1 # self.beta_T = beta_T # self.mode = mode ## beta_1 = 1e-4 -> very small variance ## beta_T = 0.02 -> large variance # if mode == 'linear': # betas = torch.linspace(beta_1, beta_T, steps=num_steps) print(f"betas: {betas.size()}, betas_0: {betas[0]}, betas_T: {betas[-1]}") # betas = torch.cat([torch.zeros([1]), betas], dim=0) # zero variance, Padding betas --> betas = betas.clone() alphas = 1 - betas log_alphas = torch.log(alphas) for i in range(1, log_alphas.size(0)): # 1 to T ## 1 to T ## variacne schedual # log_alphas[i] += log_alphas[i - 1] alpha_bars = log_alphas.exp() sigmas_flex = torch.sqrt(betas) sigmas_inflex = torch.zeros_like(sigmas_flex) for i in range(1, sigmas_flex.size(0)): # sigma inflex sigmas_inflex[i] = ((1 - alpha_bars[i-1]) / (1 - alpha_bars[i])) * betas[i] sigmas_inflex = torch.sqrt(sigmas_inflex) self.register_buffer('betas', betas) self.register_buffer('alphas', alphas) self.register_buffer('alpha_bars', alpha_bars) self.register_buffer('sigmas_flex', sigmas_flex) self.register_buffer('sigmas_inflex', sigmas_inflex) def uniform_sample_t(self, batch_size): ts = np.random.choice(np.arange(1, self.num_steps+1), batch_size) return ts.tolist() def get_sigmas(self, t, flexibility): assert 0 <= flexibility and flexibility <= 1 sigmas = self.sigmas_flex[t] * flexibility + self.sigmas_inflex[t] * (1 - flexibility) return sigmas class GaussianDiffusionV5: """ Utilities for training and sampling diffusion models. Ported directly from here, and then adapted over time to further experimentation. https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 :param betas: a 1-D numpy array of betas for each diffusion timestep, ## pass floating point timesteps ## starting at T and going to 1. :param model_mean_type: a ModelMeanType determining what the model outputs. :param model_var_type: a ModelVarType determining how variance is output. :param loss_type: a LossType determining the loss function to use. :param rescale_timesteps: if True, pass floating point timesteps into the model so that they are always scaled like in the original paper (0 to 1000). 
""" def __init__( self, *, betas, model_mean_type, model_var_type, loss_type, rescale_timesteps=False, lambda_rcxyz=0., lambda_vel=0., lambda_pose=1., lambda_orient=1., lambda_loc=1., data_rep='rot6d', lambda_root_vel=0., lambda_vel_rcxyz=0., lambda_fc=0., denoising_stra="rep", inter_optim=False, args=None, ): self.model_mean_type = model_mean_type self.model_var_type = model_var_type ## model var type ## self.loss_type = loss_type self.rescale_timesteps = rescale_timesteps self.data_rep = data_rep self.args = args # possibly None ### GET the diff. suit ### self.diff_jts = self.args.diff_jts self.diff_basejtsrel = self.args.diff_basejtsrel self.diff_basejtse = self.args.diff_basejtse self.diff_realbasejtsrel = self.args.diff_realbasejtsrel ### GET the diff. suit ### if data_rep != 'rot_vel' and lambda_pose != 1.: raise ValueError('lambda_pose is relevant only when training on velocities!') self.lambda_pose = lambda_pose self.lambda_orient = lambda_orient self.lambda_loc = lambda_loc self.lambda_rcxyz = lambda_rcxyz self.lambda_vel = lambda_vel self.lambda_root_vel = lambda_root_vel self.lambda_vel_rcxyz = lambda_vel_rcxyz self.lambda_fc = lambda_fc ### === denoising_stra for the denoising process === ### self.denoising_stra = denoising_stra self.inter_optim = inter_optim if self.lambda_rcxyz > 0. or self.lambda_vel > 0. or self.lambda_root_vel > 0. or \ self.lambda_vel_rcxyz > 0. or self.lambda_fc > 0.: assert self.loss_type == LossType.MSE, 'Geometric losses are supported by MSE loss type only!' self.var_sched = VarianceSchedule(len(betas), torch.tensor(betas, dtype=torch.float64)) # Use float64 for accuracy. betas = np.array(betas, dtype=np.float64) self.betas = betas assert len(betas.shape) == 1, "betas must be 1-D" ## betas assert (betas > 0).all() and (betas <= 1).all() self.num_timesteps = int(betas.shape[0]) alphas = 1.0 - betas self.alphas_cumprod = np.cumprod(alphas, axis=0) self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) ## alphas cumprod self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) # calculations for diffusion q(x_t | x_{t-1}) and others self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) # calculations for posterior q(x_{t-1} | x_t, x_0) self.posterior_variance = ( betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) ) # log calculation clipped because the posterior variance is 0 at the # beginning of the diffusion chain. self.posterior_log_variance_clipped = np.log( np.append(self.posterior_variance[1], self.posterior_variance[1:]) ) self.posterior_mean_coef1 = ( betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) ) self.posterior_mean_coef2 = ( # posterior mean coefs # (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) ) self.l2_loss = lambda a, b: (a - b) ** 2 # th.nn.MSELoss(reduction='none') # must be None for handling mask later on. 
''' Load statistics ''' # avg_joints_motion_ours_fn = f"/home/xueyi/sim/motion-diffusion-model/avg_joints_motion_ours_nb_{700}_nth_{0.005}.npy" # std_joints_motion_ours_fn = f"/home/xueyi/sim/motion-diffusion-model/std_joints_motion_ours_nb_{700}_nth_{0.005}.npy" # avg_joints_motion_dists_ours_fn = f"/home/xueyi/sim/motion-diffusion-model/avg_joints_dist_motion_ours_nb_{700}_nth_{0.005}.npy" # std_joints_motion_dists_ours_fn = f"/home/xueyi/sim/motion-diffusion-model/std_joints_dist_motion_ours_nb_{700}_nth_{0.005}.npy" # avg_joints_rel = np.load(avg_joints_motion_ours_fn, allow_pickle=True) # std_joints_rel = np.load(std_joints_motion_ours_fn, allow_pickle=True) # avg_joints_dists = np.load(avg_joints_motion_dists_ours_fn, allow_pickle=True) # std_joints_dists = np.load(std_joints_motion_dists_ours_fn, allow_pickle=True) # ## self.avg_joints_rel, self.std_joints_rel # ## self.avg_joints_dists, self.std_joints_dists # self.avg_joints_rel = torch.from_numpy(avg_joints_rel).float() # self.std_joints_rel = torch.from_numpy(std_joints_rel).float() # self.avg_joints_dists = torch.from_numpy(avg_joints_dists).float() # self.std_joints_dists = torch.from_numpy(std_joints_dists).float() ''' Load statistics ''' ''' Load avg, std statistics ''' # # self.maxx_rel, minn_rel, maxx_dists, minn_dists # # rel_dists_stats_fn = "/home/xueyi/sim/motion-diffusion-model/base_pts_rel_dists_stats.npy" # rel_dists_stats = np.load(rel_dists_stats_fn, allow_pickle=True).item() # maxx_rel = rel_dists_stats['maxx_rel'] # minn_rel = rel_dists_stats['minn_rel'] # maxx_dists = rel_dists_stats['maxx_dists'] # minn_dists = rel_dists_stats['minn_dists'] # self.maxx_rel = torch.from_numpy(maxx_rel).float() # self.minn_rel = torch.from_numpy(minn_rel).float() # self.maxx_dists = torch.from_numpy(maxx_dists).float() # self.minn_dists = torch.from_numpy(minn_dists).float() ''' Load avg, std statistics ''' ''' Load avg-jts, std-jts ''' # avg_jts_fn = "/home/xueyi/sim/motion-diffusion-model/avg_joints_motion_ours.npy" # std_jts_fn = "/home/xueyi/sim/motion-diffusion-model/std_joints_motion_ours.npy" # avg_jts = np.load(avg_jts_fn, allow_pickle=True) # std_jts = np.load(std_jts_fn, allow_pickle=True) # # self.avg_jts, self.std_jts # # self.avg_jts = torch.from_numpy(avg_jts).float() # self.std_jts = torch.from_numpy(std_jts).float() ''' Load avg-jts, std-jts ''' def masked_l2(self, a, b, mask): # assuming a.shape == b.shape == bs, J, Jdim, seqlen # assuming mask.shape == bs, 1, 1, seqlen loss = self.l2_loss(a, b) loss = sum_flat(loss * mask.float()) # gives \sigma_euclidean over unmasked elements n_entries = a.shape[1] * a.shape[2] non_zero_elements = sum_flat(mask) * n_entries ## non-zero-elements # print('mask', mask.shape) # print('non_zero_elements', non_zero_elements) # print('loss', loss) mse_loss_val = loss / non_zero_elements # print('mse_loss_val', mse_loss_val) return mse_loss_val def q_mean_variance(self, x_start, t): """ Get the distribution q(x_t | x_0). # q-mean-variance # :param x_start: the [N x C x ...] tensor of noiseless inputs. :param t: the number of diffusion steps (minus 1). Here, 0 means one step. :return: A tuple (mean, variance, log_variance), all of x_start's shape. 
""" mean = ( _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start ) variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) log_variance = _extract_into_tensor( self.log_one_minus_alphas_cumprod, t, x_start.shape ) return mean, variance, log_variance def q_sample(self, x_start, t, noise=None): """ Diffuse the dataset for a given number of diffusion steps. In other words, sample from q(x_t | x_0). ## q pos :param x_start: the initial dataset batch. :param t: the number of diffusion steps (minus 1). Here, 0 means one step. :param noise: if specified, the split-out normal noise. :return: A noisy version of x_start. """ if noise is None: noise = th.randn_like(x_start) assert noise.shape == x_start.shape return ( _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise ### variance xxx noise ### ) ## q_sample, def q_posterior_mean_variance(self, x_start, x_t, t): ## q_posterior """ Compute the mean and variance of the diffusion posterior: q(x_{t-1} | x_t, x_0) """ assert x_start.shape == x_t.shape posterior_mean = ( # posterior mean and variance # _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t ) posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) posterior_log_variance_clipped = _extract_into_tensor( self.posterior_log_variance_clipped, t, x_t.shape ) assert ( posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == x_start.shape[0] ) return posterior_mean, posterior_variance, posterior_log_variance_clipped # phy project pred joints; phy predict joints here # def phy_projct_pred_joints(self, pred_joints, base_pts, base_normals): # pred_joints: bsz x nf x nn_jts x 3 # # pred joints # # base_pts: bsz x nn_base_pts x 3 # # base_normals: bsz x nn_base_pts x 3 # nf = pred_joints.size(1) if not self.args.use_arti_obj: base_pts_exp = base_pts.unsqueeze(1).repeat(1, nf, 1, 1).contiguous() base_normals_exp = base_normals.unsqueeze(1).repeat(1, nf, 1, 1).contiguous() else: base_pts_exp = base_pts.clone() base_normals_exp = base_normals.clone() nearest_pred_joints_to_base_pts = torch.sum( (pred_joints.unsqueeze(-2) - base_pts_exp.unsqueeze(2)) ** 2, dim=-1 ) nearest_dist, nearest_base_pts_idxes = torch.min(nearest_pred_joints_to_base_pts, dim=-1) # bsz x nf x nn_jts nearest_dist = torch.sqrt(nearest_dist) nearest_base_pts = model_util.batched_index_select_ours(base_pts_exp, nearest_base_pts_idxes, dim=2) # bsz x nf x nn_jts x 3 nearest_base_normals = model_util.batched_index_select_ours(base_normals_exp, nearest_base_pts_idxes, dim=2) # bsz x nf x nn_jts x 3 # jts_to_base_pts = pred_joints - nearest_base_pts # from base pts to pred joints # dot_rel_with_normals = torch.sum( jts_to_base_pts * nearest_base_normals, dim=-1 # bsz x nf x nn_jts --> joints inside of the object # ) jts_proj_dir = torch.zeros_like(nearest_base_pts) # bsz x nf x nn_jts x 3 # jts_proj_dir[dot_rel_with_normals < 0.] = jts_to_base_pts[dot_rel_with_normals < 0.] # bsz x nf x nn_jts x 3 # return jts_proj_dir # bsz x nf x nn_jts x 3 # returned gradients # # # the full physical world here? ## def p_mean_variance_cond( ## get mean data ## self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None ): """ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of the initial x, x_0. 
:param model: the model, which takes a signal and a batch of timesteps as input. :param x: the [N x C x ...] tensor at time t. :param t: a 1-D Tensor of timesteps. :param clip_denoised: if True, clip the denoised signal into [-1, 1]. # # clip the denoised signal ## :param denoised_fn: if not None, a function which applies to the x_start prediction before it is used to sample. Applies before clip_denoised. # denoised fn :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :return: a dict with the following keys: - 'mean': the model mean output. - 'variance': the model variance output. - 'log_variance': the log of 'variance'. - 'pred_xstart': the prediction for x_0. """ # p_mean_varaince # if model_kwargs is None: model_kwargs = {} B = x['base_pts'].shape[0] assert t.shape == (B,) # print(f"t_shape: {t.shape}", "base_pts", x['base_pts'].size()) input_data = x out_dict = model(input_data, self._scale_timesteps(t).clone()) rt_dict = {} real_basejtsrel_seq_rt_dict = {} basejtsrel_seq_rt_dict = {} realbasejtsrel_to_joints_rt_dict = {} model_variance, model_log_variance = self.posterior_variance, self.posterior_log_variance_clipped if self.diff_realbasejtsrel and self.diff_basejtsrel: # diff_realbasejtsrel, real_basejtsrel_seq_rt_dict pert_rel_base_pts_outputs = x['pert_joints_offset_sequence'] basejtsrel_output = out_dict['joints_offset_output'] score_jts = basejtsrel_output # print(f"basejtsrel_output: {basejtsrel_output.size()}") # if self.args.use_var_sched: # bsz = basejtsrel_output.size(0) # t_item = t[0].item() # alpha = self.var_sched.alphas[t_item] # alpha_bar = self.var_sched.alpha_bars[t_item] # sigma = self.var_sched.get_sigmas(t_item, 0.) # sigma # # c0 = 1.0 / torch.sqrt(alpha) # c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar) # beta = self.var_sched.betas[[t[0].item()] * bsz] # z = torch.randn_like(basejtsrel_output) if t_item > 0 else torch.zeros_like(basejtsrel_output) # basejtsrel_output = c0 * (pert_rel_base_pts_outputs - c1 * basejtsrel_output) + sigma * z # theta # else: # basejtsrel_output = self._predict_xstart_from_eps(pert_rel_base_pts_outputs, t=t, eps=basejtsrel_output) real_dec_basejtsrel = out_dict['real_dec_basejtsrel'] # bsz x nf x nnj x nnb x 3 pert_rel_base_pts_to_rhand_joints = x['pert_rel_base_pts_to_rhand_joints'] # combine those two things # if self.args.pred_diff_noise and not self.args.train_enc: if self.args.add_noise_onjts: # add noise onjts # if self.args.real_basejtsrel_norm_stra == "std": ### std denormed_rel_base_pts_to_rhand_joints = pert_rel_base_pts_to_rhand_joints * x['std_rel_base_pts_to_rhand_joints'] + x['avg_rel_base_pts_to_rhand_joints'] else: denormed_rel_base_pts_to_rhand_joints = pert_rel_base_pts_to_rhand_joints # jts_fr_basepts = denormed_rel_base_pts_to_rhand_joints + x['normed_base_pts'].unsqueeze(1).unsqueeze(1) # jts_fr_basepts = jts_fr_basepts.mean(dim=-2) jts_fr_basepts = pert_rel_base_pts_outputs # pert # score_jts_fr_basepts = real_dec_basejtsrel.mean(dim=-2) # bsz x seq x nnjts x 3 score_jts_fr_basepts = real_dec_basejtsrel[..., self.args.sel_basepts_idx, :] t_item = t[0].item() alpha_bar = self.var_sched.alpha_bars[t_item] # combined_socre = score_jts - torch.sqrt(1. - alpha_bar) * score_jts_fr_basepts # combined_socre = score_jts # - torch.sqrt(1. - alpha_bar) * socre_jts_fr_basepts # combined_socre = score_jts_fr_basepts if self.args.only_cmb_finger: combined_socre = score_jts # combined_socre[..., -5:, :] = combined_socre[..., -5:, :] - (torch.sqrt(1. 
- alpha_bar) * score_jts_fr_basepts)[..., -5:, :] # combined_socre[..., -5:, :] = combined_socre[..., -5:, :] * 0.7 + score_jts_fr_basepts[..., -5:, :] * 0.3 # combined_socre = combined_socre * 0.2 + score_jts_fr_basepts * 0.8 combined_socre = combined_socre * 0.1 + score_jts_fr_basepts * 0.9 # combined_socre = combined_socre * 0.05 + score_jts_fr_basepts * 0.95 # combined_socre = combined_socre * 0.5 + score_jts_fr_basepts * 0.5 # combined_socre = combined_socre # combined_socre = score_jts_fr_basepts else: combined_socre = score_jts - torch.sqrt(1. - alpha_bar) * score_jts_fr_basepts combined_socre = score_jts # not cmb finger # # combined_socre = score_jts_fr_basepts # print(f"real_dec_basejtsrel: {real_dec_basejtsrel.size()}, score_jts_fr_basepts: {score_jts_fr_basepts.size()}, combined_socre: {combined_socre.size()}, score_jts: {score_jts.size()}") if self.args.use_var_sched: bsz = real_dec_basejtsrel.size(0) t_item = t[0].item() alpha = self.var_sched.alphas[t_item] alpha_bar = self.var_sched.alpha_bars[t_item] sigma = self.var_sched.get_sigmas(t_item, 0.) # sigma # # sigma = sigma / 2. # use_var_sched -> # c0 = 1.0 / torch.sqrt(alpha) c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar) beta = self.var_sched.betas[[t[0].item()] * bsz] z = torch.randn_like(jts_fr_basepts) if t_item > 0 else torch.zeros_like(jts_fr_basepts) dec_jts_fr_basepts = c0 * (jts_fr_basepts - c1 * combined_socre) + sigma * z # theta else: # x_{t-1} dec_jts_fr_basepts = self._predict_xstart_from_eps(jts_fr_basepts, t=t, eps=combined_socre) if not self.args.use_arti_obj: real_dec_basejtsrel = dec_jts_fr_basepts.unsqueeze(-2) - x['normed_base_pts'].unsqueeze(1).unsqueeze(1) else: real_dec_basejtsrel = dec_jts_fr_basepts.unsqueeze(-2) - x['normed_base_pts'].unsqueeze(2) # print(f"dec_jts_fr_basepts: {dec_jts_fr_basepts.size()}, normed_base_pts: ", x['normed_base_pts'].size(), "real_dec_basejtsrel:", real_dec_basejtsrel.size()) if self.args.real_basejtsrel_norm_stra == "std": real_dec_basejtsrel = (real_dec_basejtsrel - x['avg_rel_base_pts_to_rhand_joints']) / x['std_rel_base_pts_to_rhand_joints'] # print(f"real_dec_basejtsrel: {real_dec_basejtsrel.size()}, denormed_rel_base_pts_to_rhand_joints: {denormed_rel_base_pts_to_rhand_joints.size()}, jts_fr_basepts: {jts_fr_basepts.size()}") elif self.args.add_noise_onjts_single: # add noise on single joint if not self.args.use_arti_obj: jts_fr_basepts = pert_rel_base_pts_to_rhand_joints + x['normed_base_pts'].unsqueeze(1).unsqueeze(1) else: jts_fr_basepts = pert_rel_base_pts_to_rhand_joints + x['normed_base_pts'].unsqueeze(2) jts_fr_basepts = jts_fr_basepts.mean(dim=-2) dec_jts_fr_basepts = real_dec_basejtsrel.mean(dim=-2) # + x['normed_base_pts'].unsqueeze(1).unsqueeze(1) # rel dec basejtsrel # # dec_jts_fr_basepts = real_dec_basejtsrel[..., 0, :] t_item = t[0].item() alpha_bar = self.var_sched.alpha_bars[t_item] socre_jts_fr_basepts = dec_jts_fr_basepts if self.args.only_cmb_finger: combined_socre = score_jts # strategy 1 --> conditioning # # combined_socre[..., -5:, :] = combined_socre[..., -5:, :] - (torch.sqrt(1. - alpha_bar) * socre_jts_fr_basepts)[..., -5:, :] # strategy 2 --> linear interpolation # # combined_socre[..., -5:, :] = combined_socre[..., -5:, :] * 0.7 + socre_jts_fr_basepts[..., -5:, :] * 0.3 combined_socre[..., -5:, :] = combined_socre[..., -5:, :] * 0.5 + socre_jts_fr_basepts[..., -5:, :] * 0.5 else: combined_socre = score_jts - torch.sqrt(1. - alpha_bar) * socre_jts_fr_basepts combined_socre = score_jts # - torch.sqrt(1. 
# - alpha_bar) * socre_jts_fr_basepts
# combined_socre = socre_jts_fr_basepts
    if self.args.use_var_sched:
        bsz = dec_jts_fr_basepts.size(0)
        alpha = self.var_sched.alphas[t_item]
        alpha_bar = self.var_sched.alpha_bars[t_item]
        sigma = self.var_sched.get_sigmas(t_item, 0.)
        c0 = 1.0 / torch.sqrt(alpha)
        c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
        beta = self.var_sched.betas[[t[0].item()] * bsz]
        z = torch.randn_like(combined_socre) if t_item > 0 else torch.zeros_like(jts_fr_basepts)
        dec_jts_fr_basepts = c0 * (pert_rel_base_pts_outputs - c1 * combined_socre) + sigma * z  # x_{t-1}: bsz x seq x nnjts x 3
    else:
        dec_jts_fr_basepts = self._predict_xstart_from_eps(pert_rel_base_pts_outputs, t=t, eps=combined_socre)
    if not self.args.use_arti_obj:
        real_dec_basejtsrel = dec_jts_fr_basepts.unsqueeze(-2) - x['normed_base_pts'].unsqueeze(1).unsqueeze(1)
    else:
        real_dec_basejtsrel = dec_jts_fr_basepts.unsqueeze(-2) - x['normed_base_pts'].unsqueeze(2)
else:  # add noise directly on the relative representation
    # avg_rel_base_pts_to_rhand_joints, std_rel_base_pts_to_rhand_joints
    if self.args.real_basejtsrel_norm_stra == "std":
        denormed_rel_base_pts_to_rhand_joints = pert_rel_base_pts_to_rhand_joints * x['std_rel_base_pts_to_rhand_joints'] + x['avg_rel_base_pts_to_rhand_joints']
    else:
        denormed_rel_base_pts_to_rhand_joints = pert_rel_base_pts_to_rhand_joints
    if not self.args.use_arti_obj:
        jts_fr_basepts = denormed_rel_base_pts_to_rhand_joints + x['normed_base_pts'].unsqueeze(1).unsqueeze(1)
    else:
        jts_fr_basepts = denormed_rel_base_pts_to_rhand_joints + x['normed_base_pts'].unsqueeze(2)
    jts_fr_basepts = jts_fr_basepts.mean(dim=-2)
    score_jts_fr_basepts = real_dec_basejtsrel.mean(dim=-2)  # bsz x seq x nnjts x 3
    t_item = t[0].item()
    alpha_bar = self.var_sched.alpha_bars[t_item]
    combined_socre = score_jts - torch.sqrt(1. - alpha_bar) * score_jts_fr_basepts
    combined_socre = score_jts  # - torch.sqrt(1. - alpha_bar) * score_jts_fr_basepts
    combined_socre = score_jts_fr_basepts  # the last assignment wins: use the score decoded from base points
    if self.args.use_var_sched:
        bsz = real_dec_basejtsrel.size(0)
        alpha = self.var_sched.alphas[t_item]
        alpha_bar = self.var_sched.alpha_bars[t_item]
        sigma = self.var_sched.get_sigmas(t_item, 0.)
        # sigma = sigma / 2.
        c0 = 1.0 / torch.sqrt(alpha)
        c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
        beta = self.var_sched.betas[[t[0].item()] * bsz]
        z = torch.randn_like(jts_fr_basepts) if t_item > 0 else torch.zeros_like(jts_fr_basepts)
        dec_jts_fr_basepts = c0 * (jts_fr_basepts - c1 * combined_socre) + sigma * z  # x_{t-1}
    else:
        dec_jts_fr_basepts = self._predict_xstart_from_eps(jts_fr_basepts, t=t, eps=combined_socre)
    if not self.args.use_arti_obj:
        real_dec_basejtsrel = dec_jts_fr_basepts.unsqueeze(-2) - x['normed_base_pts'].unsqueeze(1).unsqueeze(1)
    else:
        real_dec_basejtsrel = dec_jts_fr_basepts.unsqueeze(-2) - x['normed_base_pts'].unsqueeze(2)
    if self.args.real_basejtsrel_norm_stra == "std":
        real_dec_basejtsrel = (real_dec_basejtsrel - x['avg_rel_base_pts_to_rhand_joints']) / x['std_rel_base_pts_to_rhand_joints']

# diff_realbasejtsrel
real_basejtsrel_seq_rt_dict = {
    "real_dec_basejtsrel": real_dec_basejtsrel,
}
basejtsrel_seq_rt_dict = {
    "basejtsrel_output": dec_jts_fr_basepts
}

if self.args.train_enc:  # denoise object base-point features instead of joints
    raise ValueError("Train enc --- not implemented yet")
    pert_obj_base_pts_feats = x['pert_obj_base_pts_feats']
    dec_obj_base_pts_feats = model.model.denoising_realbasejtsrel_objbasefeats(pert_obj_base_pts_feats, t)
    dec_obj_base_pts_feats = dec_obj_base_pts_feats.permute(1, 0, 2)
    pert_obj_base_pts_feats = pert_obj_base_pts_feats.permute(1, 0, 2)
    if self.args.pred_diff_noise:
        bsz = dec_obj_base_pts_feats.size(0)
        t_item = t[0].item()
        alpha = self.var_sched.alphas[t_item]
        alpha_bar = self.var_sched.alpha_bars[t_item]
        sigma = self.var_sched.get_sigmas(t_item, 0.)
        c0 = 1.0 / torch.sqrt(alpha)
        c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
        beta = self.var_sched.betas[[t[0].item()] * bsz]
        z = torch.randn_like(dec_obj_base_pts_feats) if t_item > 0 else torch.zeros_like(dec_obj_base_pts_feats)
        dec_obj_base_pts_feats = c0 * (pert_obj_base_pts_feats - c1 * dec_obj_base_pts_feats) + sigma * z
    dec_obj_base_pts_feats = dec_obj_base_pts_feats.permute(1, 0, 2)
    real_dec_basejtsrel = model.model.decode_realbasejtsrel_from_objbasefeats(dec_obj_base_pts_feats, x)
    real_basejtsrel_seq_rt_dict.update(
        {
            'real_dec_basejtsrel': real_dec_basejtsrel,
            'dec_obj_base_pts_feats': dec_obj_base_pts_feats,
        }
    )

if self.diff_basejtsrel and self.args.diff_realbasejtsrel_to_joints:
    pert_rel_base_pts_outputs = x['pert_joints_offset_sequence']
    basejtsrel_output = out_dict['joints_offset_output']
    score_jts = basejtsrel_output
    pert_joints_offset_output = x['pert_joints_offset_sequence']
    dec_joints_offset_output = out_dict['joints_offset_output_from_rel']
    score_jts_fr_rel = dec_joints_offset_output
    if self.args.pred_diff_noise:  # eps -> estimated noise
        if self.args.use_var_sched:
            bsz = dec_joints_offset_output.size(0)  # batch size of the batch; pert-joints, predicted-joints
            t_item = t[0].item()
            alpha_bar = self.var_sched.alpha_bars[t_item]
            combined_socre = score_jts - torch.sqrt(1. - alpha_bar) * score_jts_fr_rel
            # combined_socre = score_jts
            # combined_socre = score_jts_fr_rel
            alpha = self.var_sched.alphas[t_item]
            sigma = self.var_sched.get_sigmas(t_item, 0.)
            c0 = 1.0 / torch.sqrt(alpha)
            c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
            beta = self.var_sched.betas[[t[0].item()] * bsz]
            z = torch.randn_like(dec_joints_offset_output) if t_item > 0 else torch.zeros_like(dec_joints_offset_output)
            dec_joints_offset_output = c0 * (pert_joints_offset_output - c1 * score_jts) + sigma * z  # x_{t-1}; note: uses score_jts, not the combined score
    realbasejtsrel_to_joints_rt_dict = {
        'dec_joints_offset_output': dec_joints_offset_output,
    }
    basejtsrel_seq_rt_dict = {
        "basejtsrel_output": dec_joints_offset_output
    }

rt_dict.update(basejtsrel_seq_rt_dict)
rt_dict.update(real_basejtsrel_seq_rt_dict)
rt_dict.update(realbasejtsrel_to_joints_rt_dict)

return rt_dict

def p_mean_variance(
    self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
):
    """
    Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
    the initial x, x_0.

    :param model: the model, which takes a signal and a batch of timesteps
                  as input.
    :param x: the [N x C x ...] tensor at time t.
    :param t: a 1-D Tensor of timesteps.
    :param clip_denoised: if True, clip the denoised signal into [-1, 1].
    :param denoised_fn: if not None, a function which applies to the
        x_start prediction before it is used to sample. Applies before
        clip_denoised.
    :param model_kwargs: if not None, a dict of extra keyword arguments to
        pass to the model. This can be used for conditioning.
    :return: a dict with the following keys:
             - 'mean': the model mean output.
             - 'variance': the model variance output.
             - 'log_variance': the log of 'variance'.
             - 'pred_xstart': the prediction for x_0.
""" if model_kwargs is None: model_kwargs = {} ## === version 1 -> predict x_start at the rel domain === ## ### == x -> formulated as model_inputs == ### ### == TODO: process x_start from the predicted value (relative positions, signed distances) with base_pts and baes_normals == ### B = x['base_pts'].shape[0] assert t.shape == (B,) input_data = x ## dec_out and out ## ## output dict ## out_dict = model(input_data, self._scale_timesteps(t).clone()) rt_dict = {} # # }[self.model_var_type] # ### === model variance and log_variance === ### ## posterior_log_variance_clipped, posterior_variance ## model_variance, model_log_variance = self.posterior_variance, self.posterior_log_variance_clipped # pmean variance if self.diff_jts: # x_t ## joints seq latents ## pert_joints_seq_latents = x['joints_seq_latents'] # x_t pred_clean_joint_seq_latents = out_dict["joints_seq_latents"] ## if self.args.pred_diff_noise: ## eps -> estimated noises ## t > for added joints latents ## pred_clean_joint_seq_latents = self._predict_xstart_from_eps(pert_joints_seq_latents.permute(1, 0, 2), t=t, eps=pred_clean_joint_seq_latents.permute(1, 0, 2)).permute(1, 0, 2) # seq x bs x d # # minn_pert_joints_seq_latents, _ = torch.min(pred_clean_joint_seq_latents.view(pred_clean_joint_seq_latents.size(0) * pred_clean_joint_seq_latents.size(1), -1).contiguous(), dim=0) # maxx_pert_joints_seq_latents, _ = torch.max(pred_clean_joint_seq_latents.view(pred_clean_joint_seq_latents.size(0) * pred_clean_joint_seq_latents.size(1), -1).contiguous(), dim=0) # print(f"pred minn_pert_joints_latents: {minn_pert_joints_seq_latents[:10]}, pred maxx_pert_joints_seq_latents: {maxx_pert_joints_seq_latents[:10]}") ## out_dict["joint_seq_output"] = model.model.dec_jts_only_fr_latents(pred_clean_joint_seq_latents)["joint_seq_output"] ## joints seq latents mean # # pred_clean_joint_seq_latents = pert_joints_seq_latents ## joints seq latents mean # joints_seq_latents_mean, _, _ = self.q_posterior_mean_variance( x_start=pred_clean_joint_seq_latents.permute(1, 0, 2), x_t=pert_joints_seq_latents.permute(1, 0, 2), t=t ) # joints_seq_latents_mean, basejtsrel_seq_latents_mean # joints_seq_latents_mean = joints_seq_latents_mean.permute(1, 0, 2) # joints_seq_latents_mean, basejtsrel_seq_latents_mean # joints_seq_latents_variance = _extract_into_tensor(model_variance, t, joints_seq_latents_mean.permute(1, 0, 2).shape).permute(1, 0, 2) joints_seq_latents_log_variance = _extract_into_tensor(model_log_variance, t, joints_seq_latents_mean.permute(1, 0, 2).shape).permute(1, 0, 2) # joint seq output # # joint seq output # joint_seq_output = out_dict["joint_seq_output"] jts_seq_rt_dict = { ### joints seq latents ### "joints_seq_latents_mean": joints_seq_latents_mean, "joints_seq_latents_variance": joints_seq_latents_variance, "joints_seq_latents_log_variance": joints_seq_latents_log_variance, ### decoded output values ### "joint_seq_output": joint_seq_output, } else: jts_seq_rt_dict = {} # /data1/sim/mdm/save/predoffset_stdscale_bsz_10_pred_diff_realbaesjtsrel_nonorm_std_for_norm_train_enc_with_diff_latents_prediffnoise_none_norm_rel_rel_to_jts_/model000007000.pt if self.args.diff_realbasejtsrel_to_joints: pert_joints_offset_output = x['pert_joints_offset_sequence'] # rel base pts outputs # dec_joints_offset_output = out_dict['joints_offset_output'] # joints offset output # if self.args.pred_diff_noise: ## eps -> estimated-noises # predict noise? 
# # if self.args.use_var_sched: bsz = dec_joints_offset_output.size(0) ## batch size fo the batch, pert-joints, predicted-joints ## t_item = t[0].item() alpha = self.var_sched.alphas[t_item] alpha_bar = self.var_sched.alpha_bars[t_item] sigma = self.var_sched.get_sigmas(t_item, 0.) # sigma # c0 = 1.0 / torch.sqrt(alpha) ### c0? c1? --> beta, z c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar) beta = self.var_sched.betas[[t[0].item()] * bsz] z = torch.randn_like(dec_joints_offset_output) if t_item > 0 else torch.zeros_like(dec_joints_offset_output) # dec_joints dec_joints_offset_output = c0 * (pert_joints_offset_output - c1 * dec_joints_offset_output) + sigma * z # theta ## use the predicted latents and pert_latents for the seq latents prediction ## dec_joints_offset_output_mean, _, _ = self.q_posterior_mean_variance( # q posterior x_start=dec_joints_offset_output, x_t=pert_joints_offset_output, t=t ) ## from model_variance to basejtsrel_seq_latents ### dec_joints_offset_output_variance = _extract_into_tensor(model_variance, t, dec_joints_offset_output.shape) dec_joints_offset_output_log_variance = _extract_into_tensor(model_log_variance, t, dec_joints_offset_output.shape) realbasejtsrel_to_joints_rt_dict = { 'dec_joints_offset_output': dec_joints_offset_output, } else: realbasejtsrel_to_joints_rt_dict = {} if self.diff_realbasejtsrel: # diff_realbasejtsrel, real_basejtsrel_seq_rt_dict real_dec_basejtsrel = out_dict['real_dec_basejtsrel'] # bsz x nf x nnj x nnb x 3 pert_rel_base_pts_to_rhand_joints = x['pert_rel_base_pts_to_rhand_joints'] if self.args.pred_diff_noise and not self.args.train_enc: if self.args.add_noise_onjts: if not self.args.use_arti_obj: jts_fr_basepts = pert_rel_base_pts_to_rhand_joints + x['normed_base_pts'].unsqueeze(1).unsqueeze(1) else: jts_fr_basepts = pert_rel_base_pts_to_rhand_joints + x['normed_base_pts'].unsqueeze(2) # jts_fr_basepts = pert_rel_base_pts_to_rhand_joints[:, :, :, 0:1, :] + x['normed_base_pts'].unsqueeze(1).unsqueeze(1) dec_jts_fr_basepts = real_dec_basejtsrel # + x['normed_base_pts'].unsqueeze(1).unsqueeze(1) # rel dec basejtsrel # ## use same noise for rep ## use noise for rep ## ### use_same_noise_for_rep --> use same noise for rep ### if self.args.use_same_noise_for_rep: # # convert them to the strategy of using single noise ## if self.args.sel_basepts_idx >= 0: # real dec base jts rel # dec_jts_fr_basepts = real_dec_basejtsrel[:, :, :, self.args.sel_basepts_idx: self.args.sel_basepts_idx + 1] else: dec_jts_fr_basepts = real_dec_basejtsrel.mean(dim=-2, keepdim=True) # dec noise # # [:, :, :, self.args.sel_basepts_idx: self.args.sel_basepts_idx + 1] # from noise and x_t to x_start; # a projection strategy for x_start; # to noise # and we want to adjust nosie # if self.args.phy_guided_sampling and t[0].item() < 1: # # phy_guided_sampling # # pred_dec_jts = self._predict_xstart_from_eps(jts_fr_basepts, t, dec_jts_fr_basepts) # bsz x nf x nn_jts x nn_base_pts x 3 # if self.args.sel_basepts_idx >= 0: # pred_dec_jts = pred_dec_jts[:, :, :, self.args.sel_basepts_idx] # bsz x nf x nn_jts x 3 # # else: # pred_dec_jts = pred_dec_jts.mean(dim=-2) # bsz x nf x nn_jts x 3 # # # phy_projct_pred_joints(self, pred_joints, base_pts, base_normals) # joints_proj_dir = self.phy_projct_pred_joints(pred_dec_jts, x['normed_base_pts'], x['base_normals']) # bsz x nf x nn_jts x 3 # x_start_projed = pred_dec_jts - joints_proj_dir # x_start_splat = x_start_projed.unsqueeze(-2) - x['normed_base_pts'].unsqueeze(1).unsqueeze(1) # bsz x nf x nn_jts x nn_base_pts x 3 ## # 
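# Sketch of the eps <-> x_0 conversions assumed by the commented-out
# projection code around here (these identities match
# `_predict_xstart_from_eps` and `_predict_eps_from_xstart` defined later
# in this class):
#
#   x_0 = sqrt(1 / alpha_bar_t) * x_t - sqrt(1 / alpha_bar_t - 1) * eps
#   eps = (sqrt(1 / alpha_bar_t) * x_t - x_0) / sqrt(1 / alpha_bar_t - 1)
#
# i.e. a predicted x_0 can always be re-encoded as an equivalent noise
# estimate at the same timestep, which is what the projection path relies on.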
dec_jts_fr_basepts_projed = self._predict_eps_from_xstart(jts_fr_basepts, t, x_start_splat) # bsz x nf x nn_jts x nn_base_pts x 3 # # dec_ratio = 0.95 # dec_jts_fr_basepts = dec_jts_fr_basepts * dec_ratio + dec_jts_fr_basepts_projed * (1. - dec_ratio) if self.args.use_var_sched: bsz = dec_jts_fr_basepts.size(0) t_item = t[0].item() alpha = self.var_sched.alphas[t_item] alpha_bar = self.var_sched.alpha_bars[t_item] sigma = self.var_sched.get_sigmas(t_item, 0.) # sigma # c0 = 1.0 / torch.sqrt(alpha) c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar) beta = self.var_sched.betas[[t[0].item()] * bsz] z = torch.randn_like(dec_jts_fr_basepts) if t_item > 0 else torch.zeros_like(dec_jts_fr_basepts) dec_jts_fr_basepts = c0 * (jts_fr_basepts - c1 * dec_jts_fr_basepts) + sigma * z # theta # bsz x seq x nnjts x nnbase x 3 # # real_dec_basejtsrel = c0 * (pert_rel_base_pts_to_rhand_joints - c1 * real_dec_basejtsrel) + sigma * z # theta ## dec jts fr base pts ## # dec jts fr # dec_jts_fr_basepts = dec_jts_fr_basepts.mean(dim=-2).unsqueeze(-2).repeat(1, 1, 1, dec_jts_fr_basepts.size(-2), 1) # bsz x seq x nnjts x nnbase x 3 # else: # x_{t-1} dec_jts_fr_basepts = self._predict_xstart_from_eps(jts_fr_basepts, t=t, eps=dec_jts_fr_basepts) if self.args.phy_guided_sampling and t[0].item() < 10: # phy_guided_sampling # # pred_dec_jts = self._predict_xstart_from_eps(jts_fr_basepts, t, dec_jts_fr_basepts) # bsz x nf x nn_jts x nn_base_pts x 3 pred_dec_jts = dec_jts_fr_basepts.clone() if self.args.sel_basepts_idx >= 0: pred_dec_jts = pred_dec_jts[:, :, :, self.args.sel_basepts_idx] # bsz x nf x nn_jts x 3 # else: pred_dec_jts = pred_dec_jts.mean(dim=-2) # bsz x nf x nn_jts x 3 # # phy_projct_pred_joints(self, pred_joints, base_pts, base_normals) joints_proj_dir = self.phy_projct_pred_joints(pred_dec_jts, x['normed_base_pts'], x['base_normals']) # bsz x nf x nn_jts x 3 x_start_projed = pred_dec_jts - joints_proj_dir # x_start_splat = x_start_projed.unsqueeze(-2) - x['normed_base_pts'].unsqueeze(1).unsqueeze(1) # bsz x nf x nn_jts x nn_base_pts x 3 ## # dec_jts_fr_basepts_projed = self._predict_eps_from_xstart(jts_fr_basepts, t, x_start_splat) # bsz x nf x nn_jts x nn_base_pts x 3 # dec_ratio = 0. dec_jts_fr_basepts = dec_jts_fr_basepts * dec_ratio + x_start_projed.unsqueeze(-2) * (1. - dec_ratio) if not self.args.use_arti_obj: real_dec_basejtsrel = dec_jts_fr_basepts - x['normed_base_pts'].unsqueeze(1).unsqueeze(1) else: real_dec_basejtsrel = dec_jts_fr_basepts - x['normed_base_pts'].unsqueeze(2) elif self.args.add_noise_onjts_single: if not self.args.use_arti_obj: jts_fr_basepts = pert_rel_base_pts_to_rhand_joints + x['normed_base_pts'].unsqueeze(1).unsqueeze(1) else: jts_fr_basepts = pert_rel_base_pts_to_rhand_joints + x['normed_base_pts'].unsqueeze(2) jts_fr_basepts = jts_fr_basepts.mean(dim=-2) # dec_jts_fr_basepts = real_dec_basejtsrel.mean(dim=-2) # + x['normed_base_pts'].unsqueeze(1).unsqueeze(1) # rel dec basejtsrel # dec_jts_fr_basepts = real_dec_basejtsrel[..., 0, :] if self.args.use_var_sched: bsz = dec_jts_fr_basepts.size(0) t_item = t[0].item() alpha = self.var_sched.alphas[t_item] alpha_bar = self.var_sched.alpha_bars[t_item] sigma = self.var_sched.get_sigmas(t_item, 0.) 
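# The `phy_guided_sampling` branch above projects predicted joints out of
# the object before continuing the reverse process. A minimal sketch of the
# intended correction, using the names from this file (the blend ratio
# `dec_ratio` is set to 0. there, i.e. the projected estimate fully
# replaces the raw one):
#
#   d_proj  = phy_projct_pred_joints(j_pred, base_pts, base_normals)
#   j_proj  = j_pred - d_proj        # push penetrating joints back to the surface
#   j_final = dec_ratio * j_pred + (1. - dec_ratio) * j_proj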
# sigma # c0 = 1.0 / torch.sqrt(alpha) c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar) beta = self.var_sched.betas[[t[0].item()] * bsz] z = torch.randn_like(jts_fr_basepts) if t_item > 0 else torch.zeros_like(jts_fr_basepts) dec_jts_fr_basepts = c0 * (jts_fr_basepts - c1 * dec_jts_fr_basepts) + sigma * z # theta # bsz x seq x nnjts x nnbase x 3 # else: # x_{t-1} dec_jts_fr_basepts = self._predict_xstart_from_eps(jts_fr_basepts, t=t, eps=dec_jts_fr_basepts) # real_dec_basejtsrel = dec_jts_fr_basepts.unsqueeze(-2) - x['normed_base_pts'].unsqueeze(1).unsqueeze(1) if not self.args.use_arti_obj: real_dec_basejtsrel = dec_jts_fr_basepts.unsqueeze(-2) - x['normed_base_pts'].unsqueeze(1).unsqueeze(1) else: real_dec_basejtsrel = dec_jts_fr_basepts.unsqueeze(-2) - x['normed_base_pts'].unsqueeze(2) else: if self.args.use_var_sched: bsz = real_dec_basejtsrel.size(0) t_item = t[0].item() alpha = self.var_sched.alphas[t_item] alpha_bar = self.var_sched.alpha_bars[t_item] sigma = self.var_sched.get_sigmas(t_item, 0.) # sigma # # sigma = sigma / 2. c0 = 1.0 / torch.sqrt(alpha) c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar) beta = self.var_sched.betas[[t[0].item()] * bsz] z = torch.randn_like(pert_rel_base_pts_to_rhand_joints) if t_item > 0 else torch.zeros_like(pert_rel_base_pts_to_rhand_joints) # z = torch.zeros_like(pert_rel_base_pts_to_rhand_joints) real_dec_basejtsrel = c0 * (pert_rel_base_pts_to_rhand_joints - c1 * real_dec_basejtsrel) + sigma * z # theta # dec_jts_fr_basepts = real_dec_basejtsrel + x['normed_base_pts'].unsqueeze(1).unsqueeze(1) # get dec_jts fr basepts # dec_jts_fr_basepts = dec_jts_fr_basepts.mean(dim=-2).unsqueeze(-2).repeat(1, 1, 1, dec_jts_fr_basepts.size(-2), 1) # repeated basepts # real_dec_basejtsrel = dec_jts_fr_basepts - x['normed_base_pts'].unsqueeze(1).unsqueeze(1) # real dec # else: # x_{t-1} real_dec_basejtsrel = self._predict_xstart_from_eps(pert_rel_base_pts_to_rhand_joints, t=t, eps=real_dec_basejtsrel) ## use the predicted latents and pert_latents for the seq latents prediction ## real_dec_basejtsrel_mean, _, _ = self.q_posterior_mean_variance( # q posterior x_start=real_dec_basejtsrel, x_t=pert_rel_base_pts_to_rhand_joints, t=t ) ## from model_variance to basejtsrel_seq_latents ### real_dec_basejtsrel_variance = _extract_into_tensor(model_variance, t, real_dec_basejtsrel.shape) real_dec_basejtsrel_log_variance = _extract_into_tensor(model_log_variance, t, real_dec_basejtsrel.shape) # diff_realbasejtsrel real_basejtsrel_seq_rt_dict = { "real_dec_basejtsrel": real_dec_basejtsrel, "real_dec_basejtsrel_mean": real_dec_basejtsrel_mean, "real_dec_basejtsrel_variance": real_dec_basejtsrel_variance, "real_dec_basejtsrel_log_variance": real_dec_basejtsrel_log_variance, } if self.args.train_enc: # pert_rel_base_pts_to_rhand_joints pert_obj_base_pts_feats = x['pert_obj_base_pts_feats'] dec_obj_base_pts_feats = model.model.denoising_realbasejtsrel_objbasefeats(pert_obj_base_pts_feats, t) dec_obj_base_pts_feats = dec_obj_base_pts_feats.permute(1, 0, 2) pert_obj_base_pts_feats = pert_obj_base_pts_feats.permute(1, 0, 2) if self.args.pred_diff_noise: bsz = dec_obj_base_pts_feats.size(0) t_item = t[0].item() alpha = self.var_sched.alphas[t_item] alpha_bar = self.var_sched.alpha_bars[t_item] sigma = self.var_sched.get_sigmas(t_item, 0.) 
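# In this `train_enc` path the same ancestral update is applied in the
# latent space of the object/base-point encoder rather than on joints: the
# features arrive seq-first (nseq x bsz x dim), are permuted to batch-first
# for the update below, permuted back, and only then decoded into relative
# positions via `decode_realbasejtsrel_from_objbasefeats`.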
# sigma # c0 = 1.0 / torch.sqrt(alpha) c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar) beta = self.var_sched.betas[[t[0].item()] * bsz] z = torch.randn_like(dec_obj_base_pts_feats) if t_item > 0 else torch.zeros_like(dec_obj_base_pts_feats) dec_obj_base_pts_feats = c0 * (pert_obj_base_pts_feats - c1 * dec_obj_base_pts_feats) + sigma * z # theta dec_obj_base_pts_feats = dec_obj_base_pts_feats.permute(1, 0, 2) real_dec_basejtsrel = model.model.decode_realbasejtsrel_from_objbasefeats(dec_obj_base_pts_feats, x) real_basejtsrel_seq_rt_dict.update( { 'real_dec_basejtsrel': real_dec_basejtsrel, 'dec_obj_base_pts_feats': dec_obj_base_pts_feats, } ) else: real_basejtsrel_seq_rt_dict = {} # else: # x_{t-1} # real_dec_basejtsrel = self._predict_xstart_from_eps(pert_rel_base_pts_to_rhand_joints, t=t, eps=real_dec_basejtsrel) if self.diff_basejtsrel: if 'basejtsrel_output' in out_dict: pert_rel_base_pts_outputs = x['pert_rel_base_pts_to_rhand_joints'] # rel base pts outputs # pert_avg_joints_sequence = x['pert_avg_joints_sequence'] basejtsrel_output = out_dict['basejtsrel_output'].transpose(-2, -3).contiguous() avg_jts_outputs = out_dict['avg_jts_outputs'] # if pert_rel_base_pts_outputs.size(0) == 1: # pert_rel_base_pts_outputs = pert_rel_base_pts_outputs.repeat(pred_basejtsrel_seq_latents.size(0), 1, 1) if self.args.pred_diff_noise: ## eps -> estimated-noises basejtsrel_output = self._predict_xstart_from_eps(pert_rel_base_pts_outputs, t=t, eps=basejtsrel_output) avg_jts_outputs = self._predict_xstart_from_eps(pert_avg_joints_sequence, t=t, eps=avg_jts_outputs) # out_dict.update( # # model.model.dec_basejtsrel_only_fr_latents(pred_basejtsrel_seq_latents, x['input_data']) # ) ## use the predicted latents and pert_latents for the seq latents prediction ## basejtsrel_output_mean, _, _ = self.q_posterior_mean_variance( x_start=basejtsrel_output, x_t=pert_rel_base_pts_outputs, t=t ) avg_jts_outputs_mean, _, _ = self.q_posterior_mean_variance( x_start=avg_jts_outputs, x_t=pert_avg_joints_sequence, t=t ) # basejtsrel_seq_latents_mean = basejtsrel_seq_latents_mean.permute(1, 0, 2) ## from model_variance to basejtsrel_seq_latents ### basejtsrel_output_variance = _extract_into_tensor(model_variance, t, basejtsrel_output_mean.shape) basejtsrel_output_log_variance = _extract_into_tensor(model_log_variance, t, basejtsrel_output_mean.shape) ## from model_variance to basejtsrel_seq_latents ### avg_jts_outputs_variance = _extract_into_tensor(model_variance, t, avg_jts_outputs_mean.shape) avg_jts_outputs_log_variance = _extract_into_tensor(model_log_variance, t, avg_jts_outputs_mean.shape) else: pert_rel_base_pts_outputs = x['pert_joints_offset_sequence'] # rel base pts outputs # basejtsrel_output = out_dict['joints_offset_output'] # print(f"pert_rel_base_pts_outputs: {pert_rel_base_pts_outputs.size()}, basejtsrel_output: {basejtsrel_output.size()}") if self.args.pred_diff_noise: ## eps -> estimated-noises # predict noise? # # b if self.args.use_var_sched: bsz = basejtsrel_output.size(0) t_item = t[0].item() alpha = self.var_sched.alphas[t_item] alpha_bar = self.var_sched.alpha_bars[t_item] sigma = self.var_sched.get_sigmas(t_item, 0.) 
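# Note on the update below: the commented `x_t = traj[t]` /
# `e_theta = self.net(...)` lines appear to be residue of the reference
# sampler this step was adapted from, and the `beta` lookup performed in
# the neighboring branches is not used by the update itself; only c0, c1,
# sigma, and z participate in computing x_{t-1} from the predicted noise.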
# sigma # c0 = 1.0 / torch.sqrt(alpha) c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar) # x_t = traj[t] # beta = self.var_sched.betas[[t[0].item()] * bsz] # if mask is not None: # x_t = x_t * mask # e_theta = self.net(x_t, beta=beta, context=context) z = torch.randn_like(basejtsrel_output) if t_item > 0 else torch.zeros_like(basejtsrel_output) basejtsrel_output = c0 * (pert_rel_base_pts_outputs - c1 * basejtsrel_output) + sigma * z # theta else: basejtsrel_output = self._predict_xstart_from_eps(pert_rel_base_pts_outputs, t=t, eps=basejtsrel_output) if self.args.phy_guided_sampling and t[0].item() < 200: # phy_guided_sampling # # pred_dec_jts = self._predict_xstart_from_eps(jts_fr_basepts, t, dec_jts_fr_basepts) # bsz x nf x nn_jts x nn_base_pts x 3 pred_dec_jts = basejtsrel_output.clone() # if self.args.sel_basepts_idx >= 0: # pred_dec_jts = pred_dec_jts[:, :, :, self.args.sel_basepts_idx] # bsz x nf x nn_jts x 3 # # else: # pred_dec_jts = pred_dec_jts.mean(dim=-2) # bsz x nf x nn_jts x 3 # # phy_projct_pred_joints(self, pred_joints, base_pts, base_normals) joints_proj_dir = self.phy_projct_pred_joints(pred_dec_jts, x['normed_base_pts'], x['base_normals']) # bsz x nf x nn_jts x 3 x_start_projed = pred_dec_jts - joints_proj_dir # x_start_splat = x_start_projed.unsqueeze(-2) - x['normed_base_pts'].unsqueeze(1).unsqueeze(1) # bsz x nf x nn_jts x nn_base_pts x 3 ## # dec_jts_fr_basepts_projed = self._predict_eps_from_xstart(jts_fr_basepts, t, x_start_splat) # bsz x nf x nn_jts x nn_base_pts x 3 # dec_ratio = 0. basejtsrel_output = basejtsrel_output * dec_ratio + x_start_projed * (1. - dec_ratio) ## use the predicted latents and pert_latents for the seq latents prediction ## # basejtsrel_output_mean, _, _ = self.q_posterior_mean_variance( # q posterior # x_start=basejtsrel_output, x_t=pert_rel_base_pts_outputs, t=t # ) # ## from model_variance to basejtsrel_seq_latents ### # basejtsrel_output_variance = _extract_into_tensor(model_variance, t, basejtsrel_output_mean.shape) # basejtsrel_output_log_variance = _extract_into_tensor(model_log_variance, t, basejtsrel_output_mean.shape) # basejtsrel_output = out_dict["basejtsrel_output"] # print(f"basejtsrel_output: {basejtsrel_output.size()}") basejtsrel_seq_rt_dict = { ### basejtsrel seq latents ### # "avg_jts_outputs": avg_jts_outputs, # "basejtsrel_output_variance": basejtsrel_output_variance, # "basejtsrel_output_log_variance": basejtsrel_output_log_variance, # # "avg_jts_outputs_variance": avg_jts_outputs_variance, # "avg_jts_outputs_log_variance": avg_jts_outputs_log_variance, "basejtsrel_output": basejtsrel_output, } else: basejtsrel_seq_rt_dict = {} if self.diff_basejtse: dec_e_along_normals = out_dict['dec_e_along_normals'] dec_e_vt_normals = out_dict['dec_e_vt_normals'] pert_e_along_normals = x['pert_e_disp_rel_to_base_along_normals'] pert_e_vt_normals = x['pert_e_disp_rel_to_base_vt_normals'] # pert_rel_base_pts_outputs = x['pert_joints_offset_sequence'] # rel base pts outputs # # basejtsrel_output = out_dict['joints_offset_output'] if self.args.pred_diff_noise: ## eps -> estimated-noises # predict noise? # # b if self.args.use_var_sched: bsz = dec_e_along_normals.size(0) t_item = t[0].item() alpha = self.var_sched.alphas[t_item] alpha_bar = self.var_sched.alpha_bars[t_item] sigma = self.var_sched.get_sigmas(t_item, 0.) 
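# For the base-point energies, the two channels (displacement energy along
# the surface normal and the component vertical to it) are denoised
# independently below: they share the same step coefficients c0/c1/sigma,
# but each channel draws its own Gaussian (z and z_vt_normals), so the two
# reverse chains stay decoupled.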
# sigma # c0 = 1.0 / torch.sqrt(alpha) c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar) beta = self.var_sched.betas[[t[0].item()] * bsz] z = torch.randn_like(dec_e_along_normals) if t_item > 0 else torch.zeros_like(dec_e_along_normals) dec_e_along_normals = c0 * (pert_e_along_normals - c1 * dec_e_along_normals) + sigma * z # theta z_vt_normals = torch.randn_like(dec_e_vt_normals) if t_item > 0 else torch.zeros_like(dec_e_vt_normals) dec_e_vt_normals = c0 * (pert_e_vt_normals - c1 * dec_e_vt_normals) + sigma * z_vt_normals # theta else: dec_e_along_normals = self._predict_xstart_from_eps(pert_e_along_normals, t=t, eps=dec_e_along_normals) dec_e_vt_normals = self._predict_xstart_from_eps(pert_e_vt_normals, t=t, eps=dec_e_vt_normals) # base_jts_e_feats = x['base_jts_e_feats'] ### x_t values here ### # pred_basejtse_seq_latents = out_dict['base_jts_e_feats'] # ### q-sampled latent mean here ### # basejtse_seq_latents_mean, _, _ = self.q_posterior_mean_variance( # x_start=pred_basejtse_seq_latents.permute(1, 0, 2), x_t=base_jts_e_feats.permute(1, 0, 2), t=t # ) # basejtse_seq_latents_mean = basejtse_seq_latents_mean.permute(1, 0, 2) # basejtse_seq_latents_variance = _extract_into_tensor(model_variance, t, basejtse_seq_latents_mean.permute(1, 0, 2).shape).permute(1, 0, 2) # basejtse_seq_latents_log_variance = _extract_into_tensor(model_log_variance, t, basejtse_seq_latents_mean.permute(1, 0, 2).shape).permute(1, 0, 2) # # base_jts_e_feats = out_dict["base_jts_e_feats"] # dec_e_along_normals = out_dict["dec_e_along_normals"] # dec_e_vt_normals = out_dict["dec_e_vt_normals"] basejtse_seq_rt_dict = { ### baesjtse seq latents ### # "basejtse_seq_latents_mean": basejtse_seq_latents_mean, # "basejtse_seq_latents_variance": basejtse_seq_latents_variance, # "basejtse_seq_latents_log_variance": basejtse_seq_latents_log_variance, "dec_e_along_normals": dec_e_along_normals, "dec_e_vt_normals": dec_e_vt_normals, } else: basejtse_seq_rt_dict = {} rt_dict.update(jts_seq_rt_dict) rt_dict.update(basejtsrel_seq_rt_dict) rt_dict.update(basejtse_seq_rt_dict) rt_dict.update(real_basejtsrel_seq_rt_dict) rt_dict.update(realbasejtsrel_to_joints_rt_dict) return rt_dict def _predict_xstart_from_eps(self, x_t, t, eps): assert x_t.shape == eps.shape return ( # extract into tensor # _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps ) def _predict_xstart_from_xprev(self, x_t, t, xprev): assert x_t.shape == xprev.shape return ( # (xprev - coef2*x_t) / coef1 _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev - _extract_into_tensor( self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape ) * x_t ) def _predict_eps_from_xstart(self, x_t, t, pred_xstart): return ( _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) def _scale_timesteps(self, t): if self.rescale_timesteps: return t.float() * (1000.0 / self.num_timesteps) return t def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): """ Compute the mean for the previous step, given a function cond_fn that computes the gradient of a conditional log probability with respect to x. In particular, cond_fn computes grad(log(p(y|x))), and we want to condition on y. This uses the conditioning strategy from Sohl-Dickstein et al. (2015). 
# """ gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs) new_mean = ( p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() ) return new_mean def condition_mean_with_grad(self, cond_fn, p_mean_var, x, t, model_kwargs=None): """ Compute the mean for the previous step, given a function cond_fn that computes the gradient of a conditional log probability with respect to x. In particular, cond_fn computes grad(log(p(y|x))), and we want to condition on y. This uses the conditioning strategy from Sohl-Dickstein et al. (2015). """ gradient = cond_fn(x, t, p_mean_var, **model_kwargs) new_mean = ( p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() ) return new_mean def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): """ Compute what the p_mean_variance output would have been, should the model's score function be conditioned by cond_fn. See condition_mean() for details on cond_fn. Unlike condition_mean(), this instead uses the conditioning strategy from Song et al (2020). """ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) eps = eps - (1 - alpha_bar).sqrt() * cond_fn( x, self._scale_timesteps(t), **model_kwargs ) out = p_mean_var.copy() out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) out["mean"], _, _ = self.q_posterior_mean_variance( x_start=out["pred_xstart"], x_t=x, t=t ) return out def condition_score_with_grad(self, cond_fn, p_mean_var, x, t, model_kwargs=None): """ Compute what the p_mean_variance output would have been, should the model's score function be conditioned by cond_fn. See condition_mean() for details on cond_fn. Unlike condition_mean(), this instead uses the conditioning strategy from Song et al (2020). """ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) eps = eps - (1 - alpha_bar).sqrt() * cond_fn( x, t, p_mean_var, **model_kwargs ) out = p_mean_var.copy() out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) out["mean"], _, _ = self.q_posterior_mean_variance( x_start=out["pred_xstart"], x_t=x, t=t ) return out def judge_activated(self, target_setting): if target_setting: return 1 else: return 0 def p_sample( ## p sample ## self, model, x, t, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, const_noise=False, ): """ # p_sample for the p_ample # Sample x_{t-1} from the model at the given timestep. :param model: the model to sample from. # gaussian diffusion # :param x: the current tensor at x_{t-1}. :param t: the value of t, starting at 0 for the first diffusion step. :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. :param denoised_fn: if not None, a function which applies to the x_start prediction before it is used to sample. :param cond_fn: if not None, this is a gradient function that acts similarly to the model. :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :return: a dict containing the following keys: - 'sample': a random sample from the model. - 'pred_xstart': a prediction of x_0. 
""" multi_activated = ( self.judge_activated(self.diff_jts) + self.judge_activated(self.args.diff_realbasejtsrel_to_joints) + self.judge_activated(self.diff_realbasejtsrel) + self.judge_activated(self.diff_basejtsrel) + self.judge_activated(self.diff_basejtse) ) > 1.5 if multi_activated: # print(f"Multiple settings activated! Using combined sampling...") p_mena_variance_fn = self.p_mean_variance_cond # p_mean else: # print(f"Single setting activated! Using single sampling...") p_mena_variance_fn = self.p_mean_variance out = p_mena_variance_fn( model, x, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) rt_dict = {} if self.diff_jts: # bsz x ws x nnj x nnb x 3 # joints_seq_latents_noise = th.randn_like(x['joints_seq_latents']) # print('const_noise', const_noise) # seq x bsz x latent_dim # if const_noise: print(f"joints latents hape, ", x['joints_seq_latents'].shape) joints_seq_latents_noise = joints_seq_latents_noise[[0]].repeat(x['joints_seq_latents'].shape[0], 1, 1) # joints_seq_latents_nonzero_mask = ( (t != 0).float().view(-1, *([1] * (len(x['joints_seq_latents'].shape) - 1))) ) # no noise when t == 0 # bsz x nseq x dim # #### ==== joints_seq_latents ===== #### # t -> seq for const nosie .... # cnanot dpeict the laten tspace very well... # joints_seq_latents_sample = out["joints_seq_latents_mean"].permute(1, 0, 2) + joints_seq_latents_nonzero_mask * th.exp(0.5 * out["joints_seq_latents_log_variance"].permute(1, 0, 2)) * joints_seq_latents_noise.permute(1, 0, 2) # nseq x bsz x dim # joints_seq_latents_sample = joints_seq_latents_sample.permute(1, 0, 2) # #### ==== joints_seq_latents ===== #### joint_seq_output = out["joint_seq_output"] jts_seq_rt_dict = { "joints_seq_latents_sample": joints_seq_latents_sample, "joint_seq_output": joint_seq_output, } else: jts_seq_rt_dict = {} if self.args.diff_realbasejtsrel_to_joints: ## args.pred to joints # dec_joints_offset_output = realbasejtsrel_to_joints_rt_dict = { 'dec_joints_offset_output': out['dec_joints_offset_output'] } else: realbasejtsrel_to_joints_rt_dict = {} if self.diff_realbasejtsrel: if self.args.train_enc or ( self.args.pred_diff_noise and self.args.use_var_sched): real_dec_basejtsrel = out['real_dec_basejtsrel'] else: real_dec_basejtsrel_noise = th.randn_like(out['real_dec_basejtsrel']) if const_noise: real_dec_basejtsrel_noise = real_dec_basejtsrel_noise[[0]].repeat(out['real_dec_basejtsrel'].shape[0], 1, 1, 1, 1) real_dec_basejtsrel_nonzero_mask = ( (t != 0).float().view(-1, *([1] * (len(out['real_dec_basejtsrel'].shape) - 1))) ) real_dec_basejtsrel = out["real_dec_basejtsrel_mean"] + real_dec_basejtsrel_nonzero_mask * th.exp(0.5 * out["real_dec_basejtsrel_log_variance"]) * real_dec_basejtsrel_noise real_basejtsrel_rt_dict = { 'real_dec_basejtsrel': real_dec_basejtsrel, } if self.args.train_enc: real_basejtsrel_rt_dict['dec_obj_base_pts_feats'] = out['dec_obj_base_pts_feats'] else: real_basejtsrel_rt_dict = {} if self.diff_basejtsrel: # baseptse # if self.args.pred_diff_noise and self.args.use_var_sched: basejtsrel_seq_latents_sample = out['basejtsrel_output'] else: ##### ===== Sample for basejtsrel_seq_latents_sample ===== ##### ### rel_base_pts_outputs mask ### basejtsrel_seq_latents_noise = th.randn_like(out['basejtsrel_output']) if const_noise: ## seq latents noise ## basejtsrel_seq_latents_noise = basejtsrel_seq_latents_noise[[0]].repeat(out['basejtsrel_output'].shape[0], 1, 1, 1, 1) basejtsrel_seq_latents_nonzero_mask = ( (t != 0).float().view(-1, *([1] * 
(len(out['basejtsrel_output'].shape) - 1))) ) # no noise when t == 0 #### ==== basejtsrel_seq_latents ===== #### ## sample latent codes -> denoise latent codes basejtsrel_seq_latents_sample = out["basejtsrel_output"] + basejtsrel_seq_latents_nonzero_mask * th.exp(0.5 * out["basejtsrel_output_log_variance"]) * basejtsrel_seq_latents_noise # basejtsrel_seq_latents_sample = basejtsrel_seq_latents_sample.permute(1, 0, 2) #### ==== basejtsrel_seq_latents ===== #### ##### ===== Sample for basejtsrel_seq_latents_sample ===== ##### basejtsrel_rt_dict = { "basejtsrel_seq_latents_sample": basejtsrel_seq_latents_sample, # "avg_jts_outputs_sample": avg_jts_outputs_sample, } else: basejtsrel_rt_dict = {} if self.diff_basejtse: # ##### ===== Sample for basejtse_seq_latents_sample ===== ##### # ### rel_base_pts_outputs mask ### # basejtse_seq_latents_noise = th.randn_like(x['base_jts_e_feats']) # # print('const_noise', const_noise) # if const_noise: # basejtse_seq_latents_noise = basejtse_seq_latents_noise[[0]].repeat(x['base_jts_e_feats'].shape[0], 1, 1, 1, 1) # basejtse_seq_latents_nonzero_mask = ( # (t != 0).float().view(-1, *([1] * (len(x['base_jts_e_feats'].shape) - 1))) # ) # no noise when t == 0 # #### ==== basejtsrel_seq_latents ===== #### # basejtse_seq_latents_sample = out["basejtse_seq_latents_mean"].permute(1, 0, 2) + basejtse_seq_latents_nonzero_mask * th.exp(0.5 * out["basejtse_seq_latents_log_variance"].permute(1, 0, 2)) * basejtse_seq_latents_noise.permute(1, 0, 2) # basejtse_seq_latents_sample = basejtse_seq_latents_sample.permute(1, 0, 2) # #### ==== basejtsrel_seq_latents ===== #### # ##### ===== Sample for basejtse_seq_latents_sample ===== ##### dec_e_along_normals = out["dec_e_along_normals"] ## # dec_e_vt_normals = out["dec_e_vt_normals"] basejtse_rt_dict = { # "basejtse_seq_latents_sample": basejtse_seq_latents_sample, "dec_e_along_normals": dec_e_along_normals, "dec_e_vt_normals": dec_e_vt_normals, } else: basejtse_rt_dict = {} rt_dict.update(jts_seq_rt_dict) rt_dict.update(basejtsrel_rt_dict) rt_dict.update(basejtse_rt_dict) rt_dict.update(real_basejtsrel_rt_dict) rt_dict.update(realbasejtsrel_to_joints_rt_dict) return rt_dict def p_sample_with_grad( self, model, x, t, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, ): """ Sample x_{t-1} from the model at the given timestep. :param model: the model to sample from. :param x: the current tensor at x_{t-1}. :param t: the value of t, starting at 0 for the first diffusion step. :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. :param denoised_fn: if not None, a function which applies to the x_start prediction before it is used to sample. :param cond_fn: if not None, this is a gradient function that acts similarly to the model. :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :return: a dict containing the following keys: - 'sample': a random sample from the model. - 'pred_xstart': a prediction of x_0. 
""" with th.enable_grad(): x = x.detach().requires_grad_() out = self.p_mean_variance( model, x, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) noise = th.randn_like(x) nonzero_mask = ( (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) ) # no noise when t == 0 if cond_fn is not None: out["mean"] = self.condition_mean_with_grad( cond_fn, out, x, t, model_kwargs=model_kwargs ) sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise return {"sample": sample, "pred_xstart": out["pred_xstart"].detach()} def p_sample_loop( self, model, shape, noise=None, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, progress=False, skip_timesteps=0, init_image=None, randomize_class=False, cond_fn_with_grad=False, dump_steps=None, const_noise=False, st_timestep=None, ): ## """ Generate samples from the model. :param model: the model module. :param shape: the shape of the samples, (N, C, H, W). :param noise: if specified, the noise from the encoder to sample. Should be of the same shape as `shape`. :param clip_denoised: if True, clip x_start predictions to [-1, 1]. :param denoised_fn: if not None, a function which applies to the x_start prediction before it is used to sample. :param cond_fn: if not None, this is a gradient function that acts similarly to the model. :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :param device: if specified, the device to create the samples on. If not specified, use a model parameter's device. :param progress: if True, show a tqdm progress bar. :param const_noise: If True, will noise all samples with the same noise throughout sampling :return: a non-differentiable batch of samples. """ final = None # if dump_steps is not None: ## dump steps is not None ## dump = [] # function, yield, enumerate! -> for i, sample in enumerate(self.p_sample_loop_progressive( model, # p_sample # shape, # p_sample # noise=noise, clip_denoised=clip_denoised, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, device=device, progress=progress, skip_timesteps=skip_timesteps, init_image=init_image, randomize_class=randomize_class, cond_fn_with_grad=cond_fn_with_grad, const_noise=const_noise, # the same noise # st_timestep=st_timestep, )): if dump_steps is not None and i in dump_steps: dump.append(deepcopy(sample)) final = sample if dump_steps is not None: return dump return final # score # socre p_sample_loop_progressive # def p_sample_loop_progressive( self, model, shape, noise=None, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, progress=False, skip_timesteps=0, init_image=None, randomize_class=False, cond_fn_with_grad=False, const_noise=False, st_timestep=None, ): # """ # p_sample loop progressive # Generate samples from the model and yield intermediate samples from each timestep of diffusion. Arguments are the same as p_sample_loop(). Returns a generator over dicts, where each dict is the return value of p_sample(). """ ####### ==== a conditional ssampling from init_images here!!! 
==== ####### ## === give joints shape here === ## ### ==== set the shape for sampling ==== ### ### === init image sshould not be none === ### base_pts = init_image['base_pts'] base_normals = init_image['base_normals'] ## base normals ## base normals ## # rel_base_pts_to_rhand_joints = init_image['rel_base_pts_to_rhand_joints'] # dist_base_pts_to_rhand_joints = init_image['dist_base_pts_to_rhand_joints'] rhand_joints = init_image['rhand_joints'] # rhand_joints = init_image['gt_rhand_joints'] if self.args.use_anchors: # rhand_joints: bsz x nf x nn_anchors x 3 # rhand_joints = init_image['pert_rhand_anchors'] ## bsz x nf x nn_anchors x 3 -> for the anchors of the rhand # # rhand_joints = rhand_joints - ## vage for whether this model can work ### # avg_joints_sequence = std_joints_sequence = torch.std(rhand_joints, dim=1) avg_joints_sequence = torch.mean(rhand_joints, dim=1) # bsz x nnj x 3 ---> for rhand joints ## ## if self.args.joint_std_v2: std_joints_sequence = torch.std(rhand_joints.view(rhand_joints.size(0), -1, 3), dim=1).unsqueeze(1) avg_joints_sequence = torch.mean(rhand_joints.view(rhand_joints.size(0), -1, 3), dim=1).unsqueeze(1) elif self.args.joint_std_v3: avg_joints_sequence = torch.mean(rhand_joints.view(rhand_joints.size(0), -1, 3), dim=1).unsqueeze(1) # ws x 1 x 3 # std_joints_sequence = torch.std((rhand_joints - avg_joints_sequence.unsqueeze(1)).view(rhand_joints.size(0), -1), dim=1).unsqueeze(1).unsqueeze(1) joints_offset_sequence = rhand_joints - avg_joints_sequence.unsqueeze(1) # bsz x nf x nnj x 3 # # if self.args.jts_sclae_stra == "std": # and only use the latents # # joints_offset_sequence = joints_offset_sequence / std_joints_sequence if not self.args.use_arti_obj: normed_base_pts = base_pts - avg_joints_sequence # bsz x nnb x 3 else: # base_pts: bsz x nf x nnb x 3 # normed_base_pts = base_pts - avg_joints_sequence.unsqueeze(1) # bsz x nf x nnb x 3 # # joints_offset_sequence_ori = joints_offset_sequence.clone() # rhand_joints_ori = rhand_joints.clone() # jts scale stra # jts scale strategies ## # std_joints_sequence.unsqueeze(1), avg_joints_sequence.unsqueeze(1) if self.args.jts_sclae_stra == "std": joints_offset_sequence = joints_offset_sequence / std_joints_sequence.unsqueeze(1) if not self.args.use_arti_obj: normed_base_pts = normed_base_pts / std_joints_sequence else: normed_base_pts = normed_base_pts / std_joints_sequence.unsqueeze(1) else: std_joints_sequence = torch.ones_like(std_joints_sequence) # if 'sampled_base_pts_nearest_obj_pc' in init_image: # ambient_init_image = { # 'sampled_base_pts_nearest_obj_pc': init_image['sampled_base_pts_nearest_obj_pc'], # 'sampled_base_pts_nearest_obj_vns': init_image['sampled_base_pts_nearest_obj_vns'], # } ####### E ####### # # per_frame_avg_disp_along_normals, per_frame_std_disp_along_normals # # per_frame_avg_disp_vt_normals, per_frame_std_disp_vt_normals # if self.args.wo_e_normalization: # init_image['per_frame_avg_disp_along_normals'] = torch.zeros_like(init_image['per_frame_avg_disp_along_normals']) # init_image['per_frame_avg_disp_vt_normals'] = torch.zeros_like(init_image['per_frame_avg_disp_vt_normals']) # init_image['per_frame_std_disp_along_normals'] = torch.ones_like(init_image['per_frame_std_disp_along_normals']) # init_image['per_frame_std_disp_vt_normals'] = torch.ones_like(init_image['per_frame_std_disp_vt_normals']) ####### E ####### if self.args.wo_rel_normalization: # per_frame_avg_joints_rel, per_frame_std_joints_rel # init_image['per_frame_avg_joints_rel'] = 
torch.zeros_like(init_image['per_frame_avg_joints_rel']) init_image['per_frame_std_joints_rel'] = torch.ones_like(init_image['per_frame_std_joints_rel']) init_image_avg_std_stats = { 'rhand_joints': init_image['rhand_joints'], 'per_frame_avg_joints_rel': init_image['per_frame_avg_joints_rel'], 'per_frame_std_joints_rel': init_image['per_frame_std_joints_rel'], 'per_frame_avg_joints_dists_rel': init_image['per_frame_avg_joints_dists_rel'], 'per_frame_std_joints_dists_rel': init_image['per_frame_std_joints_dists_rel'], } if device is None: device = next(model.parameters()).device assert isinstance(shape, (tuple, list)) ### without e normalization ### # indicies indices = list(range(self.num_timesteps - skip_timesteps))[::-1] if st_timestep is not None: ## indices indices = indices[-st_timestep: ] print(f"st_timestep: {st_timestep}, indices: {indices}") # joints_scaling_factor = 5. # rel_base_pts_to_rhand_joints: bsz x ws x nnj x nnb x 3 ## # base_pts: bsz x nnb x 3 # # base_normals: bsz x nnb x 3 ## # rel_base_pts_to_rhand_joints = rhand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1) # bsz x nnj x nnb x 3 # rhand_joints ## # init_image['per_frame_avg_joints_rel'] = torch.mean(rel_base_pts_to_rhand_joints, dim=1, keepdim=True) # # x_start['per_frame_avg_joints_rel'] = torch # # bsz x ws x nnj x nnb x 3 # # rel_base_pts_to_rhand_joints = (rel_base_pts_to_rhand_joints - init_image['per_frame_avg_joints_rel']) / init_image['per_frame_std_joints_rel'] if not self.args.use_arti_obj: ### rel base pts to rhand joints #joints offset sequence # joints offset sequence # joints # joints_offset_sequence - normed_base_pts rel_base_pts_to_rhand_joints = joints_offset_sequence.unsqueeze(-2) - normed_base_pts.unsqueeze(1).unsqueeze(1) # bsz x nnf x nnj x nnb x 3 --> relative positions from baes pts to rhand joints # else: # rel_base_pts_to_rhand_joints = joints_offset_sequence.unsqueeze(-2) - normed_base_pts.unsqueeze(2) # bsz x nf x nn_joints x nn_base_pts x 3 # maxx_rel_base_pts_to_rhand_joints, _ = torch.max(rel_base_pts_to_rhand_joints.view(-1, 3), dim=0) minn_rel_base_pts_to_rhand_joints, _ = torch.min(rel_base_pts_to_rhand_joints.view(-1, 3), dim=0) print(f"maxx_rel_base_pts_to_rhand_joints: {maxx_rel_base_pts_to_rhand_joints}, minn_rel_base_pts_to_rhand_joints: {minn_rel_base_pts_to_rhand_joints}") if self.args.real_basejtsrel_norm_stra == "mean": rel_base_pts_to_rhand_joints = rhand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1) bsz, nf, nnj, nnb = rel_base_pts_to_rhand_joints.size()[:4] # exp_rel_base_pts_to_rhand_joints = rel_base_pts_to_rhand_joints.view(bsz, nf * nnj * nnb, -1).contiguous() # avg_rel_base_pts_to_rhand_joints = torch.mean(exp_rel_base_pts_to_rhand_joints, dim=1, keepdim=True) # bsz x 1 x 3 # # std_rel_base_pts_to_rhand_joints = torch.std(exp_rel_base_pts_to_rhand_joints, dim=1, keepdim=True) # bsz x 1 x 3 # #### rel_base_pts_to_rhand_joints # rel_base_pts_to_rhand_joints = (rel_base_pts_to_rhand_joints - avg_rel_base_pts_to_rhand_joints.unsqueeze(1).unsqueeze(1)) # / std_rel_base_pts_to_rhand_joints.unsqueeze(1).unsqueeze(1) rel_base_pts_to_rhand_joints = (rel_base_pts_to_rhand_joints - rel_base_pts_to_rhand_joints.mean(dim=0, keepdim=True)) elif self.args.real_basejtsrel_norm_stra == "std": # avg_rel_base_pts_to_rhand_joints, std_rel_base_pts_to_rhand_joints # rel_base_pts_to_rhand_joints = rhand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1) bsz, nf, nnj, nnb = rel_base_pts_to_rhand_joints.size()[:4] # rel exp_rel_base_pts_to_rhand_joints = 
rel_base_pts_to_rhand_joints.view(bsz, nf * nnj * nnb, -1).contiguous() avg_rel_base_pts_to_rhand_joints = torch.mean(exp_rel_base_pts_to_rhand_joints, dim=1, keepdim=True) # bsz x 1 x 3 rel_base_pts_to_rhand_joints = rel_base_pts_to_rhand_joints - avg_rel_base_pts_to_rhand_joints.unsqueeze(1).unsqueeze(1) std_rel_base_pts_to_rhand_joints = torch.std(rel_base_pts_to_rhand_joints.view(bsz, -1), dim=-1, keepdim=True).unsqueeze(1) # bsz x 1 x 1 rel_base_pts_to_rhand_joints = rel_base_pts_to_rhand_joints / std_rel_base_pts_to_rhand_joints.unsqueeze(1).unsqueeze(1) if self.denoising_stra == "rep": ''' Normalization Strategy 4 ''' my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0] # t con # normed_rhand_joints = (rhand_joints - self.avg_jts.unsqueeze(0).to(rhand_joints.device)) / self.std_jts.unsqueeze(0).to(rhand_joints.device) # noise_rhand_joints = th.randn_like(normed_rhand_joints) # pert_normed_rhand_joints = self.q_sample(normed_rhand_joints, my_t, noise=noise_rhand_joints) # rhand_joints = (rhand_joints - avg_exp_rhand_joints.unsqueeze(1) ) / extents_rhand_joints.unsqueeze(1) # scaled_rhand_joints = rhand_joints * joints_scaling_factor # noise_scaled_rhand_joints = th.randn_like(scaled_rhand_joints) # pert_scaled_rhand_joints = self.q_sample(scaled_rhand_joints, my_t, noise=noise_scaled_rhand_joints) # pert_rhand_joints: bsz x nnj x 3 ## -> # pert_rhand_joints = pert_normed_rhand_joints * self.std_jts.unsqueeze(0).to(rhand_joints.device) + self.avg_jts.unsqueeze(0).to(rhand_joints.device) ####### E ####### # if not self.args.use_arti_obj: # # rel_base_pts_to_rhand_joints: bsz x nf x nnj x nnb x 3 # # denormed_rel_base_pts_to_rhand_joints = rhand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1) # bsz x nf x nnj x nnb x 3 # else: # denormed_rel_base_pts_to_rhand_joints = rhand_joints.unsqueeze(-2) - base_pts.unsqueeze(2) # # ### Calculate moving related energies ### # # denormed_rhand_joints: bsz x nf x nnj x 3 -> denormed rhand joints here # # # pert rhand joints # ## denormed base pts to rhand joints -> denormed rel positions ## # # denormed_rel_base_pts_to_rhand_joints = rel_base_pts_to_rhand_joints * init_image['per_frame_std_joints_rel'] + init_image['per_frame_avg_joints_rel'] # # denormed_rhand_joints = rhand_joints * extents_rhand_joints.unsqueeze(1) + avg_exp_rhand_joints.unsqueeze(1) # # denormed_dist_rel_base_pts_to_pert_rhand_joints = torch.sum( # # denormed_rel_base_pts_to_rhand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1 # # ) ## l2 real base pts # k_f = 1. 
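# A minimal commented sketch of the "std" normalization branch above, assuming
# `rel` stands in for rel_base_pts_to_rhand_joints with shape bsz x nf x nnj x nnb x 3
# (the names `center` / `scale` are illustrative only):
#   center = rel.view(bsz, nf * nnj * nnb, 3).mean(dim=1, keepdim=True)   # bsz x 1 x 3
#   rel = rel - center.unsqueeze(1).unsqueeze(1)                          # broadcast over nf, nnj, nnb
#   scale = rel.view(bsz, -1).std(dim=-1, keepdim=True)                   # one scalar std per sequence
#   rel = rel / scale.view(bsz, 1, 1, 1, 1)
# i.e. per-sequence centering by the mean relative offset, then scaling by a single
# scalar std so the x/y/z channels stay on a common scale.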
## l2 rel base pts to pert rhand joints ## # # l2_rel_base_pts_to_pert_rhand_joints: bsz x nf x nnj x nnb # # l2_rel_base_pts_to_pert_rhand_joints = torch.norm(denormed_rel_base_pts_to_rhand_joints, dim=-1) # ### att_forces ## # att_forces = torch.exp(-k_f * l2_rel_base_pts_to_pert_rhand_joints) # bsz x nf x nnj x nnb # # # bsz x (ws - 1) x nnj x nnb # # att_forces = att_forces[:, :-1, :, :] # attraction forces -1 # # # rhand_joints: ws x nnj x 3 # -> (ws - 1) x nnj x 3 ## rhand_joints ## # # bsz x (ws - 1) x nnj x 3 --> displacements s# # # denormed_rhand_joints_disp = denormed_rhand_joints[:, 1:, :, :] - denormed_rhand_joints[:, :-1, :, :] # denormed_rhand_joints_disp = rhand_joints[:, 1:, :, :] - rhand_joints[:, :-1, :, :] # # distance -- base_normalss,; (ws - 1) x nnj x nnb x 3 --> bsz x (ws - 1) x nnj x nnb # # # signed_dist_base_pts_to_pert_rhand_joints_along_normal # bsz x (ws - 1) x nnj x nnb # # if not self.args.use_arti_obj: # signed_dist_base_pts_to_rhand_joints_along_normal = torch.sum( # base_normals.unsqueeze(1).unsqueeze(1) * denormed_rhand_joints_disp.unsqueeze(-2), dim=-1 # ) # # rel_base_pts_to_pert_rhand_joints_vt_normal: bsz x (ws -1) x nnj x nnb x 3 -> the relative positions vertical to base normals # # rel_base_pts_to_rhand_joints_vt_normal = denormed_rhand_joints_disp.unsqueeze(-2) - signed_dist_base_pts_to_rhand_joints_along_normal.unsqueeze(-1) * base_normals.unsqueeze(1).unsqueeze(1) # else: # signed_dist_base_pts_to_rhand_joints_along_normal = torch.sum( # base_normals[:, :-1].unsqueeze(2) * denormed_rhand_joints_disp.unsqueeze(-2), dim=-1 # ) # # rel_base_pts_to_pert_rhand_joints_vt_normal: bsz x (ws -1) x nnj x nnb x 3 -> the relative positions vertical to base normals # # rel_base_pts_to_rhand_joints_vt_normal = denormed_rhand_joints_disp.unsqueeze(-2) - signed_dist_base_pts_to_rhand_joints_along_normal.unsqueeze(-1) * base_normals[:, :-1].unsqueeze(2) # bsz x nf x nn_joints x nn_base_pts x 3 --> rel # dist_base_pts_to_rhand_joints_vt_normal = torch.sqrt(torch.sum( # rel_base_pts_to_rhand_joints_vt_normal ** 2, dim=-1 # )) # k_a = 1. # k_b = 1. 
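# A consolidated commented sketch of the contact-energy terms derived in the
# commented-out code above and just below, assuming joints (bsz x nf x nnj x 3),
# base_pts (bsz x nnb x 3) and unit-length base_normals (bsz x nnb x 3); the names
# `rel`, `att`, `disp`, `d_n`, `d_vt` are illustrative stand-ins, while k_f, k_a, k_b
# are the constants set above:
#   rel = joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1)    # bsz x nf x nnj x nnb x 3
#   att = torch.exp(-k_f * rel.norm(dim=-1))[:, :-1]                   # attraction weights, bsz x (nf-1) x nnj x nnb
#   disp = joints[:, 1:] - joints[:, :-1]                              # per-frame joint displacements
#   d_n = (base_normals[:, None, None] * disp.unsqueeze(-2)).sum(-1)   # signed component along the normal
#   d_vt = (disp.unsqueeze(-2) - d_n.unsqueeze(-1) * base_normals[:, None, None]).norm(dim=-1)
#   e_along = k_a * att * d_n.abs()   # motion along the contact normal
#   e_vt = k_b * att * d_vt           # motion tangential to the contact normal
# so joints close to a base point (large att) are penalized for moving, which is the
# moving-consistency energy the surrounding comments describe.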
# ### # e_disp_rel_to_base_along_normals = k_a * att_forces * torch.abs(signed_dist_base_pts_to_rhand_joints_along_normal) # # (ws - 1) x nnj x nnb # -> dist vt normals # ## # e_disp_rel_to_baes_vt_normals = k_b * att_forces * dist_base_pts_to_rhand_joints_vt_normal ####### E ####### # nf x nnj x nnb ---> dist_vt_normals -> nf x nnj x nnb # # torch.sqrt() ## # dist_base_pts_to_pert_rhand_joints, rel_base_pts_to_pert_rhand_joints # # e_disp_rel_to_base_along_normals, e_disp_rel_to_baes_vt_normals # # # print("per frmae avg jalong normals,", x_start['per_frame_avg_disp_along_normals'].size(), "per frame std along normals", x_start['per_frame_std_disp_along_normals'].size(), "e_disp_rel_to_base_along_normals", e_disp_rel_to_base_along_normals.size(), "e_disp_rel_to_baes_vt_normals", e_disp_rel_to_baes_vt_normals.size()) ### e_disp_rel_to_base_along_normals ### ---> # e_disp_rel_to_base_along_normals = (e_disp_rel_to_base_along_normals - init_image['per_frame_avg_disp_along_normals']) / init_image['per_frame_std_disp_along_normals'] # e_disp_rel_to_baes_vt_normals = (e_disp_rel_to_baes_vt_normals - init_image['per_frame_avg_disp_vt_normals']) / init_image['per_frame_std_disp_vt_normals'] ''' Normalization Strategy 4 ''' else: raise ValueError(f"Unrecognized denoising stra: {self.denoising_stra}") # my_t = th.tensor([indices[-1]] * shape[0], device=device) my_t = th.tensor([indices[0]] * shape[0], device=device) # clean_joint_seq_latents = model(input_data, self._scale_timesteps(my_t)) # noise_joint_seq_latents = th.randn_like(clean_joint_seq_latents) # # pert_joint_seq_latents: bsz x seq x d # # pert_joint_seq_latents = self.q_sample(clean_joint_seq_latents.permute(1, 0, 2).contiguous(), my_t, noise=noise_joint_seq_latents.permute(1, 0, 2).contiguous()).permute(1, 0, 2).contiguous() ####### E ####### # if not self.args.wo_e_normalization and self.args.e_normalization_stra == "cent": # bsz = e_disp_rel_to_base_along_normals.size(0) # nf, nnj, nnb = e_disp_rel_to_base_along_normals.size()[1:] # high dimensional # # the max value and min value of all values # #bs z x nnf x nnj x nnb --> for the along normals values and vt normals values ## # maxx_e_disp_rel_to_base_along_normals, _ = torch.max(e_disp_rel_to_base_along_normals.view(bsz, -1), dim=-1) # minn_e_disp_rel_to_base_along_normals , _ = torch.min(e_disp_rel_to_base_along_normals.view(bsz, -1), dim=-1) # maxx_e_disp_rel_to_base_vt_normals , _ = torch.max(e_disp_rel_to_baes_vt_normals.view(bsz, -1), dim=-1) # minn_e_disp_rel_to_base_vt_normals , _ = torch.min(e_disp_rel_to_baes_vt_normals.view(bsz, -1), dim=-1) # maxx_e_disp_rel_to_base_along_normals = maxx_e_disp_rel_to_base_along_normals.unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, 1, nnj, nnb) # minn_e_disp_rel_to_base_along_normals = minn_e_disp_rel_to_base_along_normals.unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, 1, nnj, nnb) # maxx_e_disp_rel_to_base_vt_normals = maxx_e_disp_rel_to_base_vt_normals.unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, 1, nnj, nnb) # minn_e_disp_rel_to_base_vt_normals = minn_e_disp_rel_to_base_vt_normals.unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, 1, nnj, nnb) # init_image['per_frame_avg_disp_along_normals'] = (maxx_e_disp_rel_to_base_along_normals + minn_e_disp_rel_to_base_along_normals) / 2. # init_image['per_frame_avg_disp_vt_normals'] = (maxx_e_disp_rel_to_base_vt_normals + minn_e_disp_rel_to_base_vt_normals) / 2. 
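# For reference, q_sample below follows the standard DDPM forward posterior,
#   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,  eps ~ N(0, I),
# so noising a clean representation x_0 at step my_t amounts to (a sketch, assuming
# the self.sqrt_alphas_cumprod / self.sqrt_one_minus_alphas_cumprod buffers of the
# ported guided-diffusion utilities):
#   eps = th.randn_like(x_0)
#   x_t = _extract_into_tensor(self.sqrt_alphas_cumprod, my_t, x_0.shape) * x_0 \
#       + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, my_t, x_0.shape) * eps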
# init_image['per_frame_std_disp_along_normals'] = torch.ones_like(init_image['per_frame_std_disp_along_normals'])
# init_image['per_frame_std_disp_vt_normals'] = torch.ones_like(init_image['per_frame_std_disp_vt_normals'])
# # normalize # base along normals #
# e_disp_rel_to_base_along_normals = (e_disp_rel_to_base_along_normals - init_image['per_frame_avg_disp_along_normals']) / init_image['per_frame_std_disp_along_normals']
# e_disp_rel_to_baes_vt_normals = (e_disp_rel_to_baes_vt_normals - init_image['per_frame_avg_disp_vt_normals']) / init_image['per_frame_std_disp_vt_normals']
####### E #######

#### add noise on joints / on relative offsets ####
# Design notes:
# - rigid objects -> moving; global pose; how about we do not add those
#   canonicalizations and only model the correct contacts?
# - attraction forces / distances: k_f = e^{-k \cdot \Vert v_o - v_h \Vert} gives the
#   proximity value between each pair of points; for points on the object, one
#   denoising target is the distance from a hand joint to the object surface.
# - manipulate the object --> add forces to the object #
# - a simple case: map joint points to object points and denoise relative positions
#   (relative positions; joint trajectory).
# - values that describe the moving consistency between the hand and the object:
#   (value negatively proportional to distances) * exp(\Vert x_o - x_h \Vert_2).
# - contact map -> or a generalized contact map ->
# add_noise_onjts, add_noise_onjts_single #
# noise_rel_base_pts_to_rhand_joints, rel_base_pts_to_rhand_joints ## pred jts #

### Add noise to rel_base_pts_to_rhand_joints ###
noise_rel_base_pts_to_rhand_joints = th.randn_like(rel_base_pts_to_rhand_joints) ## bsz x ws x nnjts x 3 -> for each joint point #
pert_rel_base_pts_to_rhand_joints = self.q_sample( ## bsz x ws x nnjts x 3 --> pert in the spatial space ###
    rel_base_pts_to_rhand_joints, my_t, noise_rel_base_pts_to_rhand_joints
)

if self.args.add_noise_onjts: ### add noise on joints ###
    if self.args.use_same_noise_for_rep: ### use same noise for rep ###
        noise_joints_offset_output = torch.randn_like(joints_offset_sequence)
        pert_joints_offset_output = self.q_sample( ## bsz x ws x nnjts x 3 --> pert in the spatial space ###
            joints_offset_sequence, my_t, noise_joints_offset_output
        )
        if not self.args.use_arti_obj:
            pert_rel_base_pts_to_rhand_joints = pert_joints_offset_output.unsqueeze(-2) - normed_base_pts.unsqueeze(1).unsqueeze(1)
            noise_rel_base_pts_to_rhand_joints = noise_joints_offset_output.unsqueeze(-2).repeat(1, 1, 1, normed_base_pts.size(1), 1)
        else:
            pert_rel_base_pts_to_rhand_joints = pert_joints_offset_output.unsqueeze(-2) - normed_base_pts.unsqueeze(2)
            noise_rel_base_pts_to_rhand_joints = noise_joints_offset_output.unsqueeze(-2).repeat(1, 1, 1, normed_base_pts.size(2), 1)
    else:
        if not self.args.use_arti_obj:
            joints_offset_output_exp = joints_offset_sequence.unsqueeze(-2).repeat(1, 1, 1, normed_base_pts.size(1), 1)
            noise_rel_base_pts_to_rhand_joints = torch.randn_like(joints_offset_output_exp) # normed_base_pts
            pert_joints_offset_output_exp = self.q_sample( ## bsz x ws x nnjts x 3 --> pert in the spatial space ###
                joints_offset_output_exp, my_t, noise_rel_base_pts_to_rhand_joints
            )
            pert_rel_base_pts_to_rhand_joints = pert_joints_offset_output_exp - normed_base_pts.unsqueeze(1).unsqueeze(1)
        else:
            joints_offset_output_exp = 
joints_offset_sequence.unsqueeze(-2).repeat(1, 1, 1, normed_base_pts.size(2), 1) noise_rel_base_pts_to_rhand_joints = torch.randn_like(joints_offset_output_exp) # normed_base_pts pert_joints_offset_output_exp = self.q_sample( ## bsz x ws x nnjts x 3 --> pert in the spatial space ### joints_offset_output_exp, my_t, noise_rel_base_pts_to_rhand_joints ) pert_rel_base_pts_to_rhand_joints = pert_joints_offset_output_exp - normed_base_pts.unsqueeze(2) elif self.args.add_noise_onjts_single: # joints offset sequence # joints offset single # noise_joints_offset_output = torch.randn_like(joints_offset_sequence) pert_joints_offset_output = self.q_sample( ## bsz x ws x nnjts x 3 --> pert in the spatial space ### joints_offset_sequence, my_t, noise_joints_offset_output ) if not self.args.use_arti_obj: pert_rel_base_pts_to_rhand_joints = pert_joints_offset_output.unsqueeze(-2) - normed_base_pts.unsqueeze(1).unsqueeze(1) noise_rel_base_pts_to_rhand_joints = noise_joints_offset_output.unsqueeze(-2).repeat(1, 1, 1, normed_base_pts.size(1), 1) else: pert_rel_base_pts_to_rhand_joints = pert_joints_offset_output.unsqueeze(-2) - normed_base_pts.unsqueeze(2) noise_rel_base_pts_to_rhand_joints = noise_joints_offset_output.unsqueeze(-2).repeat(1, 1, 1, normed_base_pts.size(2), 1) if self.args.train_enc: pert_rel_base_pts_to_rhand_joints = rel_base_pts_to_rhand_joints # ### Add noise to rel_baes_pts_to_rhand_joints ### # noise_rel_base_pts_to_rhand_joints = th.randn_like(rel_base_pts_to_rhand_joints) ## bsz x ws x nnjts x 3 -> for each joint point # # pert_rel_base_pts_to_rhand_joints = self.q_sample( ## bsz x ws x nnjts x 3 --> pert in the spatial space ### # rel_base_pts_to_rhand_joints, my_t, noise_rel_base_pts_to_rhand_joints # ) # noise_avg_joints_sequence = th.randn_like(avg_joints_sequence) # pert_avg_joints_sequence = self.q_sample( # bsz x nnjts x 3 --> the avg joitns for each ... 
#     avg_joints_sequence, my_t, noise_avg_joints_sequence
# )

# joints_offset_sequence # joints offset sequence ##
noise_joints_offset_sequence = th.randn_like(joints_offset_sequence)
print(f"my_t: {my_t}")
pert_joints_offset_sequence = self.q_sample(
    joints_offset_sequence, my_t, noise_joints_offset_sequence
)

if self.args.add_noise_onjts_single:
    noise_joints_offset_sequence = noise_joints_offset_output
    pert_joints_offset_sequence = pert_joints_offset_output

if not self.args.use_arti_obj:
    if self.args.add_noise_onjts_single or (self.diff_realbasejtsrel and self.diff_basejtsrel):
        pert_rel_base_pts_to_rhand_joints = pert_joints_offset_sequence.unsqueeze(-2) - normed_base_pts.unsqueeze(1).unsqueeze(1)
    if self.args.finetune_with_cond:
        pert_rel_base_pts_to_rhand_joints = pert_joints_offset_sequence.unsqueeze(-2) - normed_base_pts.unsqueeze(1).unsqueeze(1)
        print(f"pert_rel_base_pts_to_rhand_joints: {pert_rel_base_pts_to_rhand_joints.size()}")
else:
    if self.args.add_noise_onjts_single or (self.diff_realbasejtsrel and self.diff_basejtsrel):
        pert_rel_base_pts_to_rhand_joints = pert_joints_offset_sequence.unsqueeze(-2) - normed_base_pts.unsqueeze(2)
    if self.args.finetune_with_cond:
        pert_rel_base_pts_to_rhand_joints = pert_joints_offset_sequence.unsqueeze(-2) - normed_base_pts.unsqueeze(2)
        print(f"pert_rel_base_pts_to_rhand_joints: {pert_rel_base_pts_to_rhand_joints.size()}")

# sv_pert_dict = {
#     'joints_offset_sequence': joints_offset_sequence.detach().cpu().numpy(),
#     'pert_joints_offset_sequence': pert_joints_offset_sequence.detach().cpu().numpy(),
#     'noise_joints_offset_sequence': noise_joints_offset_sequence.detach().cpu().numpy(),
#     'joints_offset_sequence_ori': joints_offset_sequence_ori.detach().cpu().numpy(),
#     'rhand_joints_ori': rhand_joints.detach().cpu().numpy(),
# }
# sv_pert_dict_fn = "tot_pert_jts_sequence_dict.npy"
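# Note on the joint-space noising path above: pert_rel_base_pts_to_rhand_joints is
# defined as (noised joint offsets) minus (clean normed base points), i.e.
#   pert_rel[..., b, :] = q_sample(joints_offset_sequence, t, eps) - normed_base_pts[b]
# so a single eps drawn in joint space is shared by every base point. That is why the
# eps target, noise_rel_base_pts_to_rhand_joints, is the joint noise repeated along
# the base-point dimension rather than an independently sampled tensor.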
# np.save(sv_pert_dict_fn, sv_pert_dict) # print(f"pert data saved to {sv_pert_dict_fn} !!!!") if self.args.rnd_noise: pert_joints_offset_sequence = noise_joints_offset_sequence # pert_avg_joints_sequence = noise_avg_joints_sequence if not self.args.use_arti_obj: ## minus normed base pts here ## # ### normed pert_rel_base_pts_to_joints_for_jts_pred = pert_joints_offset_sequence.unsqueeze(-2) - normed_base_pts.unsqueeze(1).unsqueeze(1) else: pert_rel_base_pts_to_joints_for_jts_pred = pert_joints_offset_sequence.unsqueeze(-2) - normed_base_pts.unsqueeze(2) # the strategy of adding noise to the representations # # tot_pert_joint = pert_joints_offset_sequence * std_joints_sequence.unsqueeze(1) + pert_avg_joints_sequence.unsqueeze(1) # np.save("tot_pert_joint.npy", tot_pert_joint.detach().cpu().numpy()) ####### E ####### # # noise_e_disp_rel_to_base_along_normals, noise_e_disp_rel_to_base_vt_normals # # # per_frame_avg_disp_along_normals, per_frame_std_disp_along_normals # wo e normalization; # # per_frame_avg_disp_vt_normals, per_frame_std_disp_vt_normals # wo_e_normalization # # noise_e_disp_rel_to_base_along_normals = torch.randn_like(e_disp_rel_to_base_along_normals) # pert_e_disp_rel_to_base_along_normals = self.q_sample( # e_disp_rel_to_base_along_normals, my_t, noise_e_disp_rel_to_base_along_normals # ) # noise_e_disp_rel_to_base_vt_normals = torch.randn_like(e_disp_rel_to_baes_vt_normals) # pert_e_disp_rel_to_base_vt_normals = self.q_sample( # e_disp_rel_to_baes_vt_normals, my_t, noise_e_disp_rel_to_base_vt_normals # ) ####### E ####### input_data = { 'base_pts': base_pts, 'base_normals': base_normals, # 'rel_base_pts_to_rhand_joints': rel_base_pts_to_pert_rhand_joints, # 'dist_base_pts_to_rhand_joints': dist_base_pts_to_pert_rhand_joints, # 'pert_rhand_joints': pert_normed_rhand_joints, # 'pert_rhand_joints': pert_scaled_rhand_joints, 'rhand_joints': rhand_joints, # 'rel_base_pts_to_rhand_joints': rel_base_pts_to_rhand_joints, # 'rel_base_pts_to_rhand_joints': rel_base_pts_to_rhand_joints.clone(), # 'avg_joints_sequence': avg_joints_sequence, # 'pert_avg_joints_sequence': pert_avg_joints_sequence, ## pert avg joints sequence # 'pert_rel_base_pts_to_rhand_joints': pert_rel_base_pts_to_rhand_joints, ## pert realtive base pts to rhand joints ## 'pert_joints_offset_sequence': pert_joints_offset_sequence, 'pert_rel_base_pts_to_rhand_joints': pert_rel_base_pts_to_rhand_joints, ## pert realtive base pts to rhand joints ## # 'pert_joints_offset_sequence': pert_joints_offset_sequence, 'normed_base_pts': normed_base_pts, 'pert_rel_base_pts_to_joints_for_jts_pred': pert_rel_base_pts_to_joints_for_jts_pred, ## pert_rel_base_pts_to_joints_for_jts_pred for the bsz x nf x nnj x nnb x 3 --> from base points to joints #### } ####### E ####### # primal space denoising -> # input_data.update( # { # 'e_disp_rel_to_base_along_normals': e_disp_rel_to_base_along_normals, # 'e_disp_rel_to_baes_vt_normals': e_disp_rel_to_baes_vt_normals, # 'pert_e_disp_rel_to_base_along_normals': pert_e_disp_rel_to_base_along_normals, # 'pert_e_disp_rel_to_base_vt_normals': pert_e_disp_rel_to_base_vt_normals, # } # ) ####### E ####### # input # input_data.update(init_image_avg_std_stats) input_data['rhand_joints'] = rhand_joints # normed # self.args.real_basejtsrel_norm_stra == "std": # avg_rel_base_pts_to_rhand_joints, std_rel_base_pts_to_rhand_joints if self.args.real_basejtsrel_norm_stra == "std": input_data.update( { 'avg_rel_base_pts_to_rhand_joints': avg_rel_base_pts_to_rhand_joints, 'std_rel_base_pts_to_rhand_joints': 
std_rel_base_pts_to_rhand_joints, } ) if self.args.train_enc: # # model(input_data, self._scale_timesteps(t).clone()) out_dict = model(input_data, self._scale_timesteps(my_t).clone()) obj_base_pts_feats = out_dict['obj_base_pts_feats'] # obj base pts feats # # noise_obj_base_pts_feats = torch.zeros_like(obj_base_pts_feats) noise_obj_base_pts_feats = torch.randn_like(obj_base_pts_feats) pert_obj_base_pts_feats = self.q_sample( obj_base_pts_feats.permute(1, 0, 2), my_t, noise_obj_base_pts_feats.permute(1, 0, 2) ).permute(1, 0, 2) if self.args.rnd_noise: pert_obj_base_pts_feats = noise_obj_base_pts_feats input_data['pert_obj_base_pts_feats'] = pert_obj_base_pts_feats model_kwargs = { k: val for k, val in init_image.items() if k not in input_data } if progress: # Lazy import so that we don't depend on tqdm. # from tqdm.auto import tqdm indices = tqdm(indices) for i_idx, i in enumerate(indices): t = th.tensor([i] * shape[0], device=device) if randomize_class and 'y' in model_kwargs: model_kwargs['y'] = th.randint(low=0, high=model.num_classes, size=model_kwargs['y'].shape, # size of y device=model_kwargs['y'].device) # device of y with th.no_grad(): # inter_optim # progress # # p_sample_with_grad ## p_sample with grid ##s # or for each joints -> the features -> sample_fn = self.p_sample_with_grad if cond_fn_with_grad else self.p_sample out = sample_fn( model, input_data, ## sample from input data ## t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, ## const_noise=const_noise, ## # new representation strategies; resolve penerations; resolve penerations ### penetrations for new representations ## ) if self.diff_basejtsrel: # basejtrel # basejtsrel_seq_latents_sample = out["basejtsrel_seq_latents_sample"] ## basejtsrle output ## ## ## basejtsrel output ## # 'real_dec_basejtsrel': real_dec_basejtsrel, # 'dec_obj_base_pts_feats': dec_obj_base_pts_feats, # if self.args.pred_joints_offset: # pred # basejtsrel_seq_latents_sample: bsz x nf x nnj x 3 # basejtsrel_seq_latents_sample --> basejtsrel_seq_latents_sample # # sampled_rhand_joints = basejtsrel_seq_latents_sample * std_joints_sequence.unsqueeze(1) + avg_jts_outputs_sample.unsqueeze(1) sampled_rhand_joints = basejtsrel_seq_latents_sample * std_joints_sequence.unsqueeze(1) + avg_joints_sequence.unsqueeze(1) # print(f"basejtsrel_seq_latents_sample: {basejtsrel_seq_latents_sample.size()}, normed_base_pts: {normed_base_pts.size()}") ### pert rel bae pts to rhand joints ### # ### normed base pts ## if not self.args.use_arti_obj: pert_rel_base_pts_to_rhand_joints = basejtsrel_seq_latents_sample.unsqueeze(-2) - normed_base_pts.unsqueeze(1).unsqueeze(1) else: pert_rel_base_pts_to_rhand_joints = basejtsrel_seq_latents_sample.unsqueeze(-2) - normed_base_pts.unsqueeze(2) # print(f"Sampling with pert_rel_base_pts_to_rhand_joints: {pert_rel_base_pts_to_rhand_joints.size()}") basejtsrel_seq_dec_in_dict = { # finetune_with_cond # 'pert_avg_joints_sequence': out["avg_jts_outputs_sample"] if 'avg_jts_outputs_sample' in out else pert_avg_joints_sequence, ## for avg-jts sequence ## 'pert_rel_base_pts_to_rhand_joints': pert_rel_base_pts_to_rhand_joints, ## pert realtive base pts to rhand joints ## 'sampled_rhand_joints': sampled_rhand_joints, ## sampled rhand joints ## # rhand joints ## ## and another choice ## another choice ## 'pert_joints_offset_sequence': out["basejtsrel_seq_latents_sample"], } input_data.update(basejtsrel_seq_dec_in_dict) else: # basejtsrel_seq_input_dict = {} basejtsrel_seq_dec_in_dict = {} if 
self.args.diff_realbasejtsrel_to_joints: # predicted x_{t-1} (normalized) ## # rel to joints dec_joints_offset_output = out['dec_joints_offset_output'] if not self.args.use_arti_obj: ## minus normed base pts here ## # from normed pts and offset outputs # pert_rel_base_pts_to_joints_for_jts_pred = dec_joints_offset_output.unsqueeze(-2) - normed_base_pts.unsqueeze(1).unsqueeze(1) else: pert_rel_base_pts_to_joints_for_jts_pred = dec_joints_offset_output.unsqueeze(-2) - normed_base_pts.unsqueeze(2) # predicted x_{t-1} before normalization # sampled_rhand_joints = dec_joints_offset_output * std_joints_sequence.unsqueeze(1) + avg_joints_sequence.unsqueeze(1) realbasejtsrel_to_joints_dec_in_dict = { 'pert_rel_base_pts_to_joints_for_jts_pred': pert_rel_base_pts_to_joints_for_jts_pred, # bsz x nf x nnj x nnb x 3 ## 'sampled_rhand_joints': sampled_rhand_joints, 'pert_joints_offset_sequence': dec_joints_offset_output, } input_data.update(realbasejtsrel_to_joints_dec_in_dict) if self.diff_realbasejtsrel : # real_dec_basejtsrel = out["real_dec_basejtsrel"] # bsz x nf x nnj x nnb x 3 # # avg_rel_base_pts_to_rhand_joints, std_rel_base_pts_to_rhand_joints.unsqueeze(1).unsqueeze(1) # add_noise_onjts, add_noise_onjts_single #### add_noise_onjts; add_noise_onjts_single #### if self.args.real_basejtsrel_norm_stra == "std" and (not self.args.add_noise_onjts) and (not self.args.add_noise_onjts_single): real_dec_basejtsrel_pred_sample = real_dec_basejtsrel * std_rel_base_pts_to_rhand_joints.unsqueeze(1).unsqueeze(1) + avg_rel_base_pts_to_rhand_joints.unsqueeze(1).unsqueeze(1) real_dec_basejtsrel_pred_sample = real_dec_basejtsrel_pred_sample + normed_base_pts.unsqueeze(1).unsqueeze(1) else: # real_dec_basejtsrel_pred_sample = real_dec_basejtsrel.clone() if not self.args.use_arti_obj: real_dec_basejtsrel_pred_sample = real_dec_basejtsrel + normed_base_pts.unsqueeze(1).unsqueeze(1) else: real_dec_basejtsrel_pred_sample = real_dec_basejtsrel + normed_base_pts.unsqueeze(2) # real dec basejtsrel pred sample # # real_dec_basejtsrel_pred_sample = real_dec_basejtsrel_pred_sample + normed_base_pts.unsqueeze(1).unsqueeze(1) # bsz x nf x nnj x nnb x 3 # # # std_joints_sequence.unsqueeze(1), avg_joints_sequence.unsqueeze(1) # real pred samples # # if self.args.use_t == 1000: # sampled_rhand_joints = real_dec_basejtsrel_pred_sample[..., self.args.sel_basepts_idx, :] # # sampled_rhand_joints = real_dec_basejtsrel_pred_sample.mean(dim=-2) # else: # sampled_rhand_joints = real_dec_basejtsrel_pred_sample.mean(dim=-2) # bsz x nf x nnj x 3 # # sampled_rhand_joints = real_dec_basejtsrel_pred_sample[..., self.args.sel_basepts_idx, :] if self.args.sel_basepts_idx >= 0: sampled_rhand_joints = real_dec_basejtsrel_pred_sample[..., self.args.sel_basepts_idx, :] else: sampled_rhand_joints = real_dec_basejtsrel_pred_sample.mean(dim=-2) # sampled_rhand_joints = real_dec_basejtsrel_pred_sample[..., 0, :] # std joints sequence; # std_joints # sampled_rhand_joints = sampled_rhand_joints * std_joints_sequence.unsqueeze(1) + avg_joints_sequence.unsqueeze(1) real_basejtsrel_dec_in_dict = { # real_dec_basejtsrel # 'pert_rel_base_pts_to_rhand_joints': real_dec_basejtsrel, ## realdecbasejtsrel # 'sampled_rhand_joints': sampled_rhand_joints, } if not self.diff_basejtsrel: real_basejtsrel_dec_in_dict['sampled_rhand_joints'] = sampled_rhand_joints if self.args.train_enc: real_basejtsrel_dec_in_dict['pert_obj_base_pts_feats'] = out['dec_obj_base_pts_feats'] input_data.update(real_basejtsrel_dec_in_dict) else: real_basejtsrel_dec_in_dict = {} if 
self.diff_basejtse: ## seq latents ## seq latents ## # basejtse_seq_latents_sample = out["basejtse_seq_latents_sample"] pert_dec_e_along_normals = out["dec_e_along_normals"] pert_dec_e_vt_normals = out["dec_e_vt_normals"] dec_e_along_normals = pert_dec_e_along_normals * init_image['per_frame_std_disp_along_normals'] + init_image['per_frame_avg_disp_along_normals'] dec_e_vt_normals = pert_dec_e_vt_normals * init_image['per_frame_std_disp_vt_normals'] + init_image['per_frame_avg_disp_vt_normals'] ## dec_e_along_normals = torch.clamp(dec_e_along_normals, min=0.) dec_e_vt_normals = torch.clamp(dec_e_vt_normals, min=0.) # scale base ## model constraints and model impacts from object a to object c ## basejtse_seq_input_dict = { 'pert_e_disp_rel_to_base_along_normals': pert_dec_e_along_normals, 'pert_e_disp_rel_to_base_vt_normals': pert_dec_e_vt_normals, 'e_disp_rel_to_base_along_normals': dec_e_along_normals, 'e_disp_rel_to_baes_vt_normals': dec_e_vt_normals, 'sampled_rhand_joints': rhand_joints, } input_data.update(basejtse_seq_input_dict) else: basejtse_seq_input_dict = {} # basejtse_seq_dec_in_dict = {} yield input_data def ddim_sample( self, model, x, t, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, eta=0.0, ): """ Sample x_{t-1} from the model using DDIM. Same usage as p_sample(). """ out_orig = self.p_mean_variance( model, x, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) if cond_fn is not None: out = self.condition_score(cond_fn, out_orig, x, t, model_kwargs=model_kwargs) else: out = out_orig # Usually our model outputs epsilon, but we re-derive it # in case we used x_start or x_prev prediction. eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) sigma = ( eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * th.sqrt(1 - alpha_bar / alpha_bar_prev) ) # Equation 12. ##; hand position and relative noise = th.randn_like(x) mean_pred = ( out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps ) nonzero_mask = ( (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) ) # no noise when t == 0 sample = mean_pred + nonzero_mask * sigma * noise return {"sample": sample, "pred_xstart": out_orig["pred_xstart"]} def ddim_sample_with_grad( self, model, x, t, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, eta=0.0, ): """ Sample x_{t-1} from the model using DDIM. Same usage as p_sample(). """ with th.enable_grad(): x = x.detach().requires_grad_() out_orig = self.p_mean_variance( model, x, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) if cond_fn is not None: out = self.condition_score_with_grad(cond_fn, out_orig, x, t, model_kwargs=model_kwargs) else: out = out_orig out["pred_xstart"] = out["pred_xstart"].detach() # Usually our model outputs epsilon, but we re-derive it # in case we used x_start or x_prev prediction. eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) sigma = ( eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * th.sqrt(1 - alpha_bar / alpha_bar_prev) ) # Equation 12. 
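# DDIM update (Song et al., Eq. 12), for reference at this step:
#   sigma_t = eta * sqrt((1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)) * sqrt(1 - alpha_bar_t / alpha_bar_{t-1})
#   x_{t-1} = sqrt(alpha_bar_{t-1}) * x0_hat + sqrt(1 - alpha_bar_{t-1} - sigma_t^2) * eps_hat + sigma_t * z
# with z ~ N(0, I) masked out when t == 0; eta = 0 recovers the deterministic DDIM path.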
noise = th.randn_like(x) mean_pred = ( out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps ) nonzero_mask = ( (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) ) # no noise when t == 0 sample = mean_pred + nonzero_mask * sigma * noise return {"sample": sample, "pred_xstart": out_orig["pred_xstart"].detach()} ## def ddim_reverse_sample( self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None, eta=0.0, ): """ Sample x_{t+1} from the model using DDIM reverse ODE. """ assert eta == 0.0, "Reverse ODE only for deterministic path" out = self.p_mean_variance( model, x, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) # Usually our model outputs epsilon, but we re-derive it # in case we used x_start or x_prev prediction. eps = ( _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - out["pred_xstart"] ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) # Equation 12. reversed mean_pred = ( out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps ) return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} def ddim_sample_loop( self, model, shape, noise=None, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, progress=False, eta=0.0, skip_timesteps=0, init_image=None, randomize_class=False, cond_fn_with_grad=False, dump_steps=None, const_noise=False, ): """ Generate samples from the model using DDIM. Same usage as p_sample_loop(). """ if dump_steps is not None: raise NotImplementedError() if const_noise == True: raise NotImplementedError() final = None for sample in self.ddim_sample_loop_progressive( model, shape, noise=noise, clip_denoised=clip_denoised, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, device=device, progress=progress, eta=eta, skip_timesteps=skip_timesteps, init_image=init_image, randomize_class=randomize_class, cond_fn_with_grad=cond_fn_with_grad, ): final = sample return final["sample"] def ddim_sample_loop_progressive( self, model, shape, noise=None, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, progress=False, eta=0.0, skip_timesteps=0, init_image=None, randomize_class=False, cond_fn_with_grad=False, ): """ Use DDIM to sample from the model and yield intermediate samples from each timestep of DDIM. Same usage as p_sample_loop_progressive(). """ if device is None: device = next(model.parameters()).device assert isinstance(shape, (tuple, list)) if noise is not None: img = noise else: img = th.randn(*shape, device=device) if skip_timesteps and init_image is None: init_image = th.zeros_like(img) indices = list(range(self.num_timesteps - skip_timesteps))[::-1] if init_image is not None: my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0] img = self.q_sample(init_image, my_t, img) if progress: # Lazy import so that we don't depend on tqdm. 
from tqdm.auto import tqdm indices = tqdm(indices) for i in indices: t = th.tensor([i] * shape[0], device=device) if randomize_class and 'y' in model_kwargs: model_kwargs['y'] = th.randint(low=0, high=model.num_classes, size=model_kwargs['y'].shape, device=model_kwargs['y'].device) with th.no_grad(): sample_fn = self.ddim_sample_with_grad if cond_fn_with_grad else self.ddim_sample out = sample_fn( model, img, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, eta=eta, ) yield out img = out["sample"] def plms_sample( self, model, x, t, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, cond_fn_with_grad=False, order=2, old_out=None, ): """ Sample x_{t-1} from the model using Pseudo Linear Multistep. Same usage as p_sample(). """ if not int(order) or not 1 <= order <= 4: raise ValueError('order is invalid (should be int from 1-4).') def get_model_output(x, t): with th.set_grad_enabled(cond_fn_with_grad and cond_fn is not None): x = x.detach().requires_grad_() if cond_fn_with_grad else x out_orig = self.p_mean_variance( model, x, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) if cond_fn is not None: if cond_fn_with_grad: out = self.condition_score_with_grad(cond_fn, out_orig, x, t, model_kwargs=model_kwargs) x = x.detach() else: out = self.condition_score(cond_fn, out_orig, x, t, model_kwargs=model_kwargs) else: out = out_orig # Usually our model outputs epsilon, but we re-derive it # in case we used x_start or x_prev prediction. eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) return eps, out, out_orig alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) eps, out, out_orig = get_model_output(x, t) if order > 1 and old_out is None: # Pseudo Improved Euler old_eps = [eps] mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev) * eps eps_2, _, _ = get_model_output(mean_pred, t - 1) eps_prime = (eps + eps_2) / 2 pred_prime = self._predict_xstart_from_eps(x, t, eps_prime) mean_pred = pred_prime * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev) * eps_prime else: # Pseudo Linear Multistep (Adams-Bashforth) old_eps = old_out["old_eps"] old_eps.append(eps) cur_order = min(order, len(old_eps)) if cur_order == 1: eps_prime = old_eps[-1] elif cur_order == 2: eps_prime = (3 * old_eps[-1] - old_eps[-2]) / 2 elif cur_order == 3: eps_prime = (23 * old_eps[-1] - 16 * old_eps[-2] + 5 * old_eps[-3]) / 12 elif cur_order == 4: eps_prime = (55 * old_eps[-1] - 59 * old_eps[-2] + 37 * old_eps[-3] - 9 * old_eps[-4]) / 24 else: raise RuntimeError('cur_order is invalid.') pred_prime = self._predict_xstart_from_eps(x, t, eps_prime) mean_pred = pred_prime * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev) * eps_prime if len(old_eps) >= order: old_eps.pop(0) nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) sample = mean_pred * nonzero_mask + out["pred_xstart"] * (1 - nonzero_mask) return {"sample": sample, "pred_xstart": out_orig["pred_xstart"], "old_eps": old_eps} def plms_sample_loop( self, model, shape, noise=None, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, progress=False, skip_timesteps=0, init_image=None, randomize_class=False, cond_fn_with_grad=False, order=2, ): """ Generate samples from the model using Pseudo Linear Multistep. Same usage as p_sample_loop(). 
""" final = None for sample in self.plms_sample_loop_progressive( model, shape, noise=noise, clip_denoised=clip_denoised, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, device=device, progress=progress, skip_timesteps=skip_timesteps, init_image=init_image, randomize_class=randomize_class, cond_fn_with_grad=cond_fn_with_grad, order=order, ): final = sample return final["sample"] def plms_sample_loop_progressive( self, model, shape, noise=None, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, progress=False, skip_timesteps=0, init_image=None, randomize_class=False, cond_fn_with_grad=False, order=2, ): """ Use PLMS to sample from the model and yield intermediate samples from each timestep of PLMS. Same usage as p_sample_loop_progressive(). """ if device is None: device = next(model.parameters()).device assert isinstance(shape, (tuple, list)) if noise is not None: img = noise else: img = th.randn(*shape, device=device) if skip_timesteps and init_image is None: init_image = th.zeros_like(img) indices = list(range(self.num_timesteps - skip_timesteps))[::-1] if init_image is not None: my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0] img = self.q_sample(init_image, my_t, img) if progress: # Lazy import so that we don't depend on tqdm. from tqdm.auto import tqdm indices = tqdm(indices) old_out = None for i in indices: t = th.tensor([i] * shape[0], device=device) if randomize_class and 'y' in model_kwargs: model_kwargs['y'] = th.randint(low=0, high=model.num_classes, size=model_kwargs['y'].shape, device=model_kwargs['y'].device) with th.no_grad(): out = self.plms_sample( model, img, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, cond_fn_with_grad=cond_fn_with_grad, order=order, old_out=old_out, ) yield out old_out = out img = out["sample"] def _vb_terms_bpd( self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None ): """ Get a term for the variational lower-bound. The resulting units are bits (rather than nats, as one might expect). This allows for comparison to other papers. :return: a dict with the following keys: - 'output': a shape [N] tensor of NLLs or KLs. - 'pred_xstart': the x_0 predictions. """ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( x_start=x_start, x_t=x_t, t=t ) out = self.p_mean_variance( model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs ) kl = normal_kl( true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] ) kl = mean_flat(kl) / np.log(2.0) decoder_nll = -discretized_gaussian_log_likelihood( x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] ) assert decoder_nll.shape == x_start.shape decoder_nll = mean_flat(decoder_nll) / np.log(2.0) # At the first timestep return the decoder NLL, # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) output = th.where((t == 0), decoder_nll, kl) return {"output": output, "pred_xstart": out["pred_xstart"]} ## ## training losses ## ## training losses ## def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, dataset=None): """ # training losses # training losses for rel/dist representations ## Compute training losses for a single timestep. :param model: the model to evaluate loss on. :param x_start: the [N x C x ...] tensor of inputs. :param t: a batch of timestep indices. :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. 
:param noise: if specified, the specific Gaussian noise to try to remove. :return: a dict with the key "loss" containing a tensor of shape [N]. Some mean or variance settings may also have other keys. """ # if self.args.train_diff: # set enc to evals # # # print(f"Setitng encoders to eval mode") # model.model.set_enc_to_eval() enc = model.model ## model.model mask = model_kwargs['y']['mask'] ## rot2xyz get_xyz = lambda sample: enc.rot2xyz(sample, mask=None, pose_rep=enc.pose_rep, translation=enc.translation, glob=enc.glob, ## rot2xyz ## ## rot2xyz ## # jointstype='vertices', # 3.4 iter/sec # USED ALSO IN MotionCLIP jointstype='smpl', # 3.4 iter/sec vertstrans=False) # bsz x ws x nnj x 3 # # # base_pts: bsz x nnb x 3 # # base_normals: bsz x nnb x 3 # # base normals # base normals # # bsz x ws x nnjts x 3 # rhand_joints = x_start['rhand_joints'] # bsz x nnbase x 3 # base_pts = x_start['base_pts'] # bsz x ws x nnbase x 3 # base_normals = x_start['base_normals'] # if self.args.use_anchors: # rhand_joints: bsz x nf x nn_anchors x 3 # ## rhand verts ## rhand_joints = x_start['rhand_anchors'] ## bsz x nf x nn_anchors x 3 -> for the anchors of the rhand # # base_pts, base_normals, rhand_joints # ### rhand verts ## avg_joints_sequence = torch.mean(rhand_joints, dim=1) # bsz x nnj x 3 ---> for rhand joints ## std_joints_sequence = torch.std(rhand_joints, dim=1) if self.args.joint_std_v2: std_joints_sequence = torch.std(rhand_joints.view(rhand_joints.size(0), -1, 3), dim=1).unsqueeze(1) avg_joints_sequence = torch.mean(rhand_joints.view(rhand_joints.size(0), -1, 3), dim=1).unsqueeze(1) elif self.args.joint_std_v3: # std_joints_sequence = torch.std(rhand_joints.view(rhand_joints.size(0), -1, 3), dim=1).unsqueeze(1) avg_joints_sequence = torch.mean(rhand_joints.view(rhand_joints.size(0), -1, 3), dim=1).unsqueeze(1) # bsz x 3 --> bsz x 1 x 3; std_joints_sequence = torch.std((rhand_joints - avg_joints_sequence.unsqueeze(1)).view(rhand_joints.size(0), -1), dim=1).unsqueeze(1).unsqueeze(1) # if self.args.use_anchor: # avg_joints_sequence = torch.mean(rhand_joints, dim=1) # normed_base_pts, joints_offset_sequence # joints_offset_sequence = rhand_joints - avg_joints_sequence.unsqueeze(1) # bsz x nf x nnj x 3 # # bsz x nf x nnj x 3 # normed_base_pts = base_pts - avg_joints_sequence # bsz x nnb x 3 if not self.args.use_arti_obj: normed_base_pts = base_pts - avg_joints_sequence # bsz x nnb x 3 else: # base_pts: bsz x nf x nnb x 3 # normed_base_pts = base_pts - avg_joints_sequence.unsqueeze(1) # bsz x nf x nnb x 3 # if self.args.jts_sclae_stra == "std": ## jts scale stra ## joints_offset_sequence = joints_offset_sequence / std_joints_sequence.unsqueeze(1) # normed_base_pts = normed_base_pts / std_joints_sequence if not self.args.use_arti_obj: normed_base_pts = normed_base_pts / std_joints_sequence else: normed_base_pts = normed_base_pts / std_joints_sequence.unsqueeze(1) else: std_joints_sequence = torch.ones_like(std_joints_sequence) # # bsz x ws x nnjts x nnbase x 3 # # rel_base_pts_to_rhand_joints = x_start['rel_base_pts_to_rhand_joints'] # # bsz x ws x nnjts x nnbase # # dist_base_pts_to_rhand_joints = x_start['dist_base_pts_to_rhand_joints'] # if 'sampled_base_pts_nearest_obj_pc' in x_start: # ambient_xstart_dict = { # 'sampled_base_pts_nearest_obj_pc': x_start['sampled_base_pts_nearest_obj_pc'], # 'sampled_base_pts_nearest_obj_vns': x_start['sampled_base_pts_nearest_obj_vns'], # } # per_frame_avg_disp_along_normals, per_frame_std_disp_along_normals # wo e normalization; # per_frame_avg_disp_vt_normals, 
per_frame_std_disp_vt_normals # wo_e_normalization # if self.args.wo_e_normalization: # per frame avg disp along normals # x_start['per_frame_avg_disp_along_normals'] = torch.zeros_like(x_start['per_frame_avg_disp_along_normals']) x_start['per_frame_avg_disp_vt_normals'] = torch.zeros_like(x_start['per_frame_avg_disp_vt_normals']) x_start['per_frame_std_disp_along_normals'] = torch.ones_like(x_start['per_frame_std_disp_along_normals']) x_start['per_frame_std_disp_vt_normals'] = torch.ones_like(x_start['per_frame_std_disp_vt_normals']) # psatial -> e normalization and centralize? if self.args.wo_rel_normalization: # per_frame_avg_joints_rel, per_frame_std_joints_rel # x_start['per_frame_avg_joints_rel'] = torch.zeros_like(x_start['per_frame_avg_joints_rel']) x_start['per_frame_std_joints_rel'] = torch.ones_like(x_start['per_frame_std_joints_rel']) # normed_base_pts, joints_offset_sequence # ## rel_base_pts_to_rhand_joints: bsz x ws x nnj x nnb x 3 ## ## base pts to rhand joints ## # base_pts: bsz x nnb x 3 # # base_normals: bsz x nnb x 3 ## ## relative joint positions ### ## bsz x # rel_base_pts_to_rhand_joints = rhand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1) # bsz x nnj x nnb x 3 # rhand_joints ## # x_start['per_frame_avg_joints_rel'] = torch.mean(rel_base_pts_to_rhand_joints, dim=1, keepdim=True) # # x_start['per_frame_avg_joints_rel'] = torch # # bsz x ws x nnj x nnb x 3 # # per_frame_avg_joints_rel # # rel_base_pts_to_rhand_joints = (rel_base_pts_to_rhand_joints - x_start['per_frame_avg_joints_rel']) / x_start['per_frame_std_joints_rel'] # rel_base_pts_to_rhand_joints ## rel_base_pts_to_rhand_joints -> joints offset ## Normalization stra 1 --> no normalization for joints sequences ## # normed base pts if not self.args.use_arti_obj: rel_base_pts_to_rhand_joints = joints_offset_sequence.unsqueeze(-2) - normed_base_pts.unsqueeze(1).unsqueeze(1) # bsz x nnf x nnj x nnb x 3 --> relative positions from baes pts to rhand joints # else: rel_base_pts_to_rhand_joints = joints_offset_sequence.unsqueeze(-2) - normed_base_pts.unsqueeze(2) # other normalization strategies> if self.args.real_basejtsrel_norm_stra == "mean": rel_base_pts_to_rhand_joints = rhand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1) bsz, nf, nnj, nnb = rel_base_pts_to_rhand_joints.size()[:4] # exp_rel_base_pts_to_rhand_joints = rel_base_pts_to_rhand_joints.view(bsz, nf * nnj * nnb, -1).contiguous() # avg_rel_base_pts_to_rhand_joints = torch.mean(exp_rel_base_pts_to_rhand_joints, dim=1, keepdim=True) # bsz x 1 x 3 # # std_rel_base_pts_to_rhand_joints = torch.std(exp_rel_base_pts_to_rhand_joints, dim=1, keepdim=True) # bsz x 1 x 3 # #### rel_base_pts_to_rhand_joints #### # rel_base_pts_to_rhand_joints = (rel_base_pts_to_rhand_joints - avg_rel_base_pts_to_rhand_joints.unsqueeze(1).unsqueeze(1)) # / std_rel_base_pts_to_rhand_joints.unsqueeze(1).unsqueeze(1) rel_base_pts_to_rhand_joints = (rel_base_pts_to_rhand_joints - rel_base_pts_to_rhand_joints.mean(dim=0, keepdim=True)) elif self.args.real_basejtsrel_norm_stra == "std": # rel_base_pts_to_rhand_joints = rhand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1) bsz, nf, nnj, nnb = rel_base_pts_to_rhand_joints.size()[:4] exp_rel_base_pts_to_rhand_joints = rel_base_pts_to_rhand_joints.view(bsz, nf * nnj * nnb, -1).contiguous() avg_rel_base_pts_to_rhand_joints = torch.mean(exp_rel_base_pts_to_rhand_joints, dim=1, keepdim=True) # bsz x 1 x 3 rel_base_pts_to_rhand_joints = rel_base_pts_to_rhand_joints - 
avg_rel_base_pts_to_rhand_joints.unsqueeze(1).unsqueeze(1) std_rel_base_pts_to_rhand_joints = torch.std(rel_base_pts_to_rhand_joints.view(bsz, -1), dim=-1, keepdim=True).unsqueeze(1) # bsz x 1 x 1 rel_base_pts_to_rhand_joints = rel_base_pts_to_rhand_joints / std_rel_base_pts_to_rhand_joints.unsqueeze(1).unsqueeze(1) # rel_base_pts_to_rhand_joints --> # # print("here using std!!") # std_rel_base_pts_to_rhand_joints = torch.std(exp_rel_base_pts_to_rhand_joints, dim=1, keepdim=True) # bsz x 1 x 3 #### rel_base_pts_to_rhand_joints #### # rel_base_pts_to_rhand_joints = (rel_base_pts_to_rhand_joints - avg_rel_base_pts_to_rhand_joints.unsqueeze(1).unsqueeze(1)) / std_rel_base_pts_to_rhand_joints.unsqueeze(1).unsqueeze(1) ## rep; motion-to-rep; # construct statistics, normalize values # # joints_scaling_factor = 5. # ''' GET rel and dists ''' ## rep and rhand_joints ##### # rep; motion-to-rep # if self.denoising_stra == "rep": # rel_base_pts_to_rhand_joints: bsz x nf x nnj x nnb x 3 # if not self.args.use_arti_obj: denormed_rel_base_pts_to_rhand_joints = rhand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1) # bsz x nf x nnj x nnb x 3 else: denormed_rel_base_pts_to_rhand_joints = rhand_joints.unsqueeze(-2) - base_pts.unsqueeze(2) # bsz x nf x nnj x nnb x 3 ##### E ###### # ### Calculate moving related energies ### # # denormed_rhand_joints: bsz x nf x nnj x 3 -> denormed rhand joints here # # # pert rhand joints # ## denormed base pts to rhand joints -> denormed rel positions ## # # denormed_rel_base_pts_to_rhand_joints = rel_base_pts_to_rhand_joints * x_start['per_frame_std_joints_rel'] + x_start['per_frame_avg_joints_rel'] # # denormed_rhand_joints = rhand_joints * extents_rhand_joints.unsqueeze(1) + avg_exp_rhand_joints.unsqueeze(1) # denormed_dist_rel_base_pts_to_pert_rhand_joints = torch.sum( # denormed_rel_base_pts_to_rhand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1 # denormed relative distances # ) ## l2 real base pts # k_f = 1. ## l2 rel base pts to pert rhand joints ## # # l2_rel_base_pts_to_pert_rhand_joints: bsz x nf x nnj x nnb # # l2_rel_base_pts_to_pert_rhand_joints = torch.norm(denormed_rel_base_pts_to_rhand_joints, dim=-1) # ### att_forces ## # att_forces = torch.exp(-k_f * l2_rel_base_pts_to_pert_rhand_joints) # bsz x nf x nnj x nnb # # # bsz x (ws - 1) x nnj x nnb # # att_forces = att_forces[:, :-1, :, :] # attraction forces -1 # # # rhand_joints: ws x nnj x 3 # -> (ws - 1) x nnj x 3 ## rhand_joints ## # # bsz x (ws - 1) x nnj x 3 --> displacements s# # denormed_rhand_joints_disp = rhand_joints[:, 1:, :, :] - rhand_joints[:, :-1, :, :] # # distance -- base_normalss,; (ws - 1) x nnj x nnb x 3 --> bsz x (ws - 1) x nnj x nnb # # # signed_dist_base_pts_to_pert_rhand_joints_along_normal # bsz x (ws - 1) x nnj x nnb # # signed_dist_base_pts_to_rhand_joints_along_normal = torch.sum( # base_normals.unsqueeze(1).unsqueeze(1) * denormed_rhand_joints_disp.unsqueeze(-2), dim=-1 # ) ## signed dist base pts to rhand joints along normals # # # rel_base_pts_to_pert_rhand_joints_vt_normal: bsz x (ws -1) x nnj x nnb x 3 -> the relative positions vertical to base normals # # rel_base_pts_to_rhand_joints_vt_normal = denormed_rhand_joints_disp.unsqueeze(-2) - signed_dist_base_pts_to_rhand_joints_along_normal.unsqueeze(-1) * base_normals.unsqueeze(1).unsqueeze(1) # dist_base_pts_to_rhand_joints_vt_normal = torch.sqrt(torch.sum( # rel_base_pts_to_rhand_joints_vt_normal ** 2, dim=-1 # )) # k_a = 1. # k_b = 1. 
### # e_disp_rel_to_base_along_normals = k_a * att_forces * torch.abs(signed_dist_base_pts_to_rhand_joints_along_normal) # # (ws - 1) x nnj x nnb # -> dist vt normals # ## # e_disp_rel_to_baes_vt_normals = k_b * att_forces * dist_base_pts_to_rhand_joints_vt_normal # # nf x nnj x nnb ---> dist_vt_normals -> nf x nnj x nnb # # torch.sqrt() ## # ##### E ###### # dist_base_pts_to_pert_rhand_joints, rel_base_pts_to_pert_rhand_joints # # e_disp_rel_to_base_along_normals, e_disp_rel_to_baes_vt_normals # # print("per frmae avg jalong normals,", x_start['per_frame_avg_disp_along_normals'].size(), "per frame std along normals", x_start['per_frame_std_disp_along_normals'].size(), "e_disp_rel_to_base_along_normals", e_disp_rel_to_base_along_normals.size(), "e_disp_rel_to_baes_vt_normals", e_disp_rel_to_baes_vt_normals.size()) # per_frame_avg_disp_along_normals, per_frame_std_disp_along_normals # per_frame_avg_disp_vt_normals, per_frame_std_disp_vt_normals # e_disp_rel_to_base_along_normals = (e_disp_rel_to_base_along_normals - x_start['per_frame_avg_disp_along_normals']) / x_start['per_frame_std_disp_along_normals'] # e_disp_rel_to_baes_vt_normals = (e_disp_rel_to_baes_vt_normals - x_start['per_frame_avg_disp_vt_normals']) / x_start['per_frame_std_disp_vt_normals'] elif self.denoising_stra == "motion_to_rep": # or sdfs. or # print(f"Using denoising stra: {self.denoising_stra}") joints_noise = torch.randn_like(rhand_joints) pert_rhand_joints = self.q_sample(rhand_joints, t, noise=joints_noise) # q_sample for the noisy joints # pert_rhand_joints: bsz x nf x nnj x 3 ## --> pert joints # base_pts: bsz x nnb x 3 # avg jts and std jts ## pert_rhand_joints_denorm = pert_rhand_joints * self.std_jts.unsqueeze(0).to(rhand_joints.device) + self.avg_jts.unsqueeze(0).to(rhand_joints.device) # jbs zx nf x nnj x nnb x 3 rel_base_pts_to_pert_rhand_joints = pert_rhand_joints_denorm.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1) dist_base_pts_to_pert_rhand_joints = torch.sum( rel_base_pts_to_pert_rhand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1 ) else: raise ValueError(f"Unrecognized denoising_stra: {self.denoising_stra}") ''' GET rel and dists ''' # noise_rel_base_pts_to_rhand_joints, rel_base_pts_to_rhand_joints ## pred jts # pred jts # ### Add noise to rel_baes_pts_to_rhand_joints ### noise_rel_base_pts_to_rhand_joints = th.randn_like(rel_base_pts_to_rhand_joints) ## bsz x ws x nnjts x 3 -> for each joint point # pert_rel_base_pts_to_rhand_joints = self.q_sample( ## bsz x ws x nnjts x 3 --> pert in the spatial space ## rel_base_pts_to_rhand_joints, t, noise_rel_base_pts_to_rhand_joints ) # add_noise_onjts, add_noise_onjts_single ### ==== add noise on joints ==== ### ### ==== add noise on joints and then use them to calculate the perturbed rel-base-pts-to-rhand-joints ==== ### if self.args.add_noise_onjts: # add_noise_onjts_single # --> # bsz x nf x nnjts x nn_base_pts x 3 joints_offset_output_exp = joints_offset_sequence.unsqueeze(-2).repeat(1, 1, 1, normed_base_pts.size(-2), 1) noise_rel_base_pts_to_rhand_joints = torch.randn_like(joints_offset_output_exp) pert_joints_offset_output_exp = self.q_sample( ## bsz x ws x nnjts x 3 ## joints_offset_output_exp, t, noise_rel_base_pts_to_rhand_joints ) # pert_rel_base_pts_to_rhand_joints: bsz x seq_len x nnj x nnb x 3 --> the rhand-joints; if not self.args.use_arti_obj: pert_rel_base_pts_to_rhand_joints = pert_joints_offset_output_exp - normed_base_pts.unsqueeze(1).unsqueeze(1) # else: pert_rel_base_pts_to_rhand_joints = pert_joints_offset_output_exp - 
normed_base_pts.unsqueeze(2) elif self.args.add_noise_onjts_single: # joints offset sequence # joints offset single # noise_joints_offset_output = torch.randn_like(joints_offset_sequence) pert_joints_offset_output = self.q_sample( ## bsz x ws x nnjts x 3 --> pert in the spatial space ### joints_offset_sequence, t, noise_joints_offset_output ) if not self.args.use_arti_obj: pert_rel_base_pts_to_rhand_joints = pert_joints_offset_output.unsqueeze(-2) - normed_base_pts.unsqueeze(1).unsqueeze(1) else: pert_rel_base_pts_to_rhand_joints = pert_joints_offset_output.unsqueeze(-2) - normed_base_pts.unsqueeze(2) noise_rel_base_pts_to_rhand_joints = noise_joints_offset_output.unsqueeze(-2).repeat(1, 1, 1, normed_base_pts.size(1), 1) # joints_offset_output_exp = joints_offset_sequence.unsqueeze(-2).repeat(1, 1, 1, normed_base_pts.size(1), 1) # noise_rel_base_pts_to_rhand_joints = torch.randn_like(joints_offset_output_exp) # noise_rel_base_pts_to_rhand_joints = noise_rel_base_pts_to_rhand_joints[:, :, :, 0:1, :].repeat(1, 1, 1, normed_base_pts.size(1), 1) # pert_joints_offset_output_exp = self.q_sample( ## bsz x ws x nnjts x 3 --> pert in the spatial space ### # joints_offset_output_exp, t, noise_rel_base_pts_to_rhand_joints # ) # pert_rel_base_pts_to_rhand_joints = pert_joints_offset_output_exp - normed_base_pts.unsqueeze(1).unsqueeze(1) if self.args.train_enc: pert_rel_base_pts_to_rhand_joints = rel_base_pts_to_rhand_joints # # bsz x ws x nnj x nnb x 3 # maxx_pert_basejtsrel, _ = torch.max(pert_rel_base_pts_to_rhand_joints.view(-1, 3), dim=0) # minn_pert_basejtsrel, _ = torch.min(pert_rel_base_pts_to_rhand_joints.view(-1, 3), dim=0) # maxx_basejtsrel, _ = torch.max(rel_base_pts_to_rhand_joints.view(-1, 3), dim=0) # minn_basejtsrel, _ = torch.min(rel_base_pts_to_rhand_joints.view(-1, 3), dim=0) # # print(f"maxx_pert_basejtsrel: {maxx_pert_basejtsrel}, minn_pert_basejtsrel: {minn_pert_basejtsrel}, maxx_basejtsrel: {maxx_basejtsrel}, minn_basejtsrel: {minn_basejtsrel}") ### Add noise to avg joints sequence # noise_avg_joints_sequence = th.randn_like(avg_joints_sequence) pert_avg_joints_sequence = self.q_sample( # bsz x nnjts x 3 --> the avg joitns for each ... 
avg_joints_sequence, t, noise_avg_joints_sequence ) ### perturbe offset-joints ### # joints_offset_sequence noise_joints_offset_sequence = th.randn_like(joints_offset_sequence) pert_joints_offset_sequence = self.q_sample( joints_offset_sequence, t, noise_joints_offset_sequence ) ### perturbe offset-joints ### # rel bae pts to if not self.args.use_arti_obj: pert_rel_base_pts_to_joints_for_jts_pred = pert_joints_offset_sequence.unsqueeze(-2) - normed_base_pts.unsqueeze(1).unsqueeze(1) else: pert_rel_base_pts_to_joints_for_jts_pred = pert_joints_offset_sequence.unsqueeze(-2) - normed_base_pts.unsqueeze(2) if self.args.use_jts_pert_realbasejtsrel: pert_rel_base_pts_to_rhand_joints = pert_rel_base_pts_to_joints_for_jts_pred noise_rel_base_pts_to_rhand_joints = noise_joints_offset_sequence.unsqueeze(-2).repeat(1, 1, 1, normed_base_pts.size(-2), 1).contiguous() if self.args.finetune_with_cond: if not self.args.use_arti_obj: pert_rel_base_pts_to_rhand_joints = pert_joints_offset_sequence.unsqueeze(-2) - normed_base_pts.unsqueeze(1).unsqueeze(1) else: pert_rel_base_pts_to_rhand_joints = pert_joints_offset_sequence.unsqueeze(-2) - normed_base_pts.unsqueeze(2) # x_start['per_frame_avg_disp_along_normals'] = torch.zeros_like(x_start['per_frame_avg_disp_along_normals']) # x_start['per_frame_avg_disp_vt_normals'] = torch.zeros_like(x_start['per_frame_avg_disp_vt_normals']) # x_start['per_frame_std_disp_along_normals'] = torch.ones_like(x_start['per_frame_std_disp_along_normals']) # bigpicture # # x_start['per_frame_std_disp_vt_normals'] = torch.ones_like(x_start['per_frame_std_disp_vt_normals']) # ##### E ###### # if not self.args.wo_e_normalization and self.args.e_normalization_stra == "cent": # # e normalization --> # bsz = e_disp_rel_to_base_along_normals.size(0) # nf, nnj, nnb = e_disp_rel_to_base_along_normals.size()[1:] # high dimensional # # the max value and min value of all values # #bs z x nnf x nnj x nnb --> for the along normals values and vt normals values ## # maxx_e_disp_rel_to_base_along_normals, _ = torch.max(e_disp_rel_to_base_along_normals.view(bsz, -1), dim=-1) # minn_e_disp_rel_to_base_along_normals , _ = torch.min(e_disp_rel_to_base_along_normals.view(bsz, -1), dim=-1) # maxx_e_disp_rel_to_base_vt_normals , _ = torch.max(e_disp_rel_to_baes_vt_normals.view(bsz, -1), dim=-1) # minn_e_disp_rel_to_base_vt_normals , _ = torch.min(e_disp_rel_to_baes_vt_normals.view(bsz, -1), dim=-1) # maxx_e_disp_rel_to_base_along_normals = maxx_e_disp_rel_to_base_along_normals.unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, 1, nnj, nnb) # minn_e_disp_rel_to_base_along_normals = minn_e_disp_rel_to_base_along_normals.unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, 1, nnj, nnb) # maxx_e_disp_rel_to_base_vt_normals = maxx_e_disp_rel_to_base_vt_normals.unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, 1, nnj, nnb) # minn_e_disp_rel_to_base_vt_normals = minn_e_disp_rel_to_base_vt_normals.unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, 1, nnj, nnb) # x_start['per_frame_avg_disp_along_normals'] = (maxx_e_disp_rel_to_base_along_normals + minn_e_disp_rel_to_base_along_normals) / 2. # x_start['per_frame_avg_disp_vt_normals'] = (maxx_e_disp_rel_to_base_vt_normals + minn_e_disp_rel_to_base_vt_normals) / 2. 
        ##### E ######
        # if not self.args.wo_e_normalization and self.args.e_normalization_stra == "cent":
        #     # e normalization -->
        #     bsz = e_disp_rel_to_base_along_normals.size(0)
        #     nf, nnj, nnb = e_disp_rel_to_base_along_normals.size()[1:]
        #     # the max value and min value of all values # bsz x nnf x nnj x nnb --> for the along-normals and vt-normals values ##
        #     maxx_e_disp_rel_to_base_along_normals, _ = torch.max(e_disp_rel_to_base_along_normals.view(bsz, -1), dim=-1)
        #     minn_e_disp_rel_to_base_along_normals, _ = torch.min(e_disp_rel_to_base_along_normals.view(bsz, -1), dim=-1)
        #     maxx_e_disp_rel_to_base_vt_normals, _ = torch.max(e_disp_rel_to_baes_vt_normals.view(bsz, -1), dim=-1)
        #     minn_e_disp_rel_to_base_vt_normals, _ = torch.min(e_disp_rel_to_baes_vt_normals.view(bsz, -1), dim=-1)
        #     maxx_e_disp_rel_to_base_along_normals = maxx_e_disp_rel_to_base_along_normals.unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, 1, nnj, nnb)
        #     minn_e_disp_rel_to_base_along_normals = minn_e_disp_rel_to_base_along_normals.unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, 1, nnj, nnb)
        #     maxx_e_disp_rel_to_base_vt_normals = maxx_e_disp_rel_to_base_vt_normals.unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, 1, nnj, nnb)
        #     minn_e_disp_rel_to_base_vt_normals = minn_e_disp_rel_to_base_vt_normals.unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, 1, nnj, nnb)
        #     x_start['per_frame_avg_disp_along_normals'] = (maxx_e_disp_rel_to_base_along_normals + minn_e_disp_rel_to_base_along_normals) / 2.
        #     x_start['per_frame_avg_disp_vt_normals'] = (maxx_e_disp_rel_to_base_vt_normals + minn_e_disp_rel_to_base_vt_normals) / 2.
        #     x_start['per_frame_std_disp_along_normals'] = torch.ones_like(x_start['per_frame_std_disp_along_normals'])
        #     x_start['per_frame_std_disp_vt_normals'] = torch.ones_like(x_start['per_frame_std_disp_vt_normals'])
        #     # normalize ##
        #     e_disp_rel_to_base_along_normals = (e_disp_rel_to_base_along_normals - x_start['per_frame_avg_disp_along_normals']) / x_start['per_frame_std_disp_along_normals']
        #     e_disp_rel_to_baes_vt_normals = (e_disp_rel_to_baes_vt_normals - x_start['per_frame_avg_disp_vt_normals']) / x_start['per_frame_std_disp_vt_normals']
        ##### E ######

        ##### E ######
        # noise_e_disp_rel_to_base_along_normals, noise_e_disp_rel_to_base_vt_normals #
        # per_frame_avg_disp_along_normals, per_frame_std_disp_along_normals # wo e normalization #
        # per_frame_avg_disp_vt_normals, per_frame_std_disp_vt_normals # wo_e_normalization #
        # noise_e_disp_rel_to_base_along_normals = torch.zeros_like(e_disp_rel_to_base_along_normals)
        # pert_e_disp_rel_to_base_along_normals = self.q_sample(
        #     e_disp_rel_to_base_along_normals, t, noise_e_disp_rel_to_base_along_normals
        # )
        # noise_e_disp_rel_to_base_vt_normals = torch.zeros_like(e_disp_rel_to_baes_vt_normals)
        # pert_e_disp_rel_to_base_vt_normals = self.q_sample(
        #     e_disp_rel_to_baes_vt_normals, t, noise_e_disp_rel_to_base_vt_normals
        # )
        ##### E ######

        input_data = {
            'base_pts': base_pts.clone(),  # base pts ###
            'base_normals': base_normals.clone(),  # base normals ###
            'rel_base_pts_to_rhand_joints': rel_base_pts_to_rhand_joints.clone(),
            'rhand_joints': rhand_joints,
            'avg_joints_sequence': avg_joints_sequence,  ## bsz x nnjoints x 3 here for the avg_joints ##
            'pert_avg_joints_sequence': pert_avg_joints_sequence,
            'pert_rel_base_pts_to_rhand_joints': pert_rel_base_pts_to_rhand_joints,  ## pert relative base pts to rhand joints ##
            'pert_joints_offset_sequence': pert_joints_offset_sequence,
            'normed_base_pts': normed_base_pts,
            'pert_rel_base_pts_to_joints_for_jts_pred': pert_rel_base_pts_to_joints_for_jts_pred,
        }
        # if 'sampled_base_pts_nearest_obj_pc' in x_start:
        #     input_data.update(ambient_xstart_dict)
        # dist_base_pts_to_pert_rhand_joints, rel_base_pts_to_pert_rhand_joints #
        # e_disp_rel_to_base_along_normals, e_disp_rel_to_baes_vt_normals # bsz x ws - 1 x nnj x nnb #
        ##### E ######
        # input_data.update(
        #     {
        #         # e_disp_rel_to_base_along_normals, e_disp_rel_to_baes_vt_normals #
        #         'e_disp_rel_to_base_along_normals': e_disp_rel_to_base_along_normals,
        #         'e_disp_rel_to_baes_vt_normals': e_disp_rel_to_baes_vt_normals,
        #         'pert_e_disp_rel_to_base_along_normals': pert_e_disp_rel_to_base_along_normals,
        #         'pert_e_disp_rel_to_base_vt_normals': pert_e_disp_rel_to_base_vt_normals,
        #     }
        # )
        ##### E ######
        # input_data.update(
        #     {k: x_start[k].clone() for k in x_start if k not in input_data}
        # )
        # gaussian diffusion ours ##
        # rel_base_pts_to_rhand_joints in the input_data #

        if model_kwargs is None:
            model_kwargs = {}
        terms = {}
        # latents in the latent space # sequence latents #
        # if self.args.train_diff:
        #     with torch.no_grad():
        #         out_dict = model(input_data, self._scale_timesteps(t).clone())
        # else:
        #     # clean_joint_seq_latents: seq x bs x d #
        #     clean_joint_seq_latents = model(input_data, self._scale_timesteps(t).clone())

        ### the strategy of removing noise from corresponding quantities ###
        out_dict = model(input_data, self._scale_timesteps(t).clone())  ### get model output dictionary ###

        KL_loss = 0.
        terms['rot_mse'] = 0.
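        # NOTE: each diff_* branch below follows one pattern: when
        # `pred_diff_noise` is set, the network output is treated as a noise
        # estimate and regressed against the Gaussian noise fed to q_sample;
        # otherwise it is regressed against the clean quantity itself.
        # Commented sketch (`out`, `noise`, `clean` are hypothetical tensors
        # of shape bsz x ws x nnjts x 3):
        #   target = noise if self.args.pred_diff_noise else clean
        #   loss = torch.sum((out - target) ** 2, dim=-1).mean(dim=-1).mean(dim=-1)  # -> per-sample loss [bsz]
        #   terms['rot_mse'] += loss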
        ### diff_jts ###
        # out dict of the model # resume checkpoints # dec_in_dict
        dec_in_dict = {}
        if self.diff_jts:
            ### Sample for perturbed joints seq latents ###
            clean_joint_seq_latents = out_dict["joint_seq_output"]
            noise_joint_seq_latents = th.randn_like(clean_joint_seq_latents)
            if self.args.const_noise:
                noise_joint_seq_latents = noise_joint_seq_latents[0].unsqueeze(0).repeat(noise_joint_seq_latents.size(0), 1, 1).contiguous()
            pert_joint_seq_latents = self.q_sample(clean_joint_seq_latents.permute(1, 0, 2).contiguous(), t, noise=noise_joint_seq_latents.permute(1, 0, 2).contiguous()).permute(1, 0, 2).contiguous()
            ### Sample for perturbed joints seq latents ###
            dec_in_dict['joints_seq_latents'] = pert_joint_seq_latents
            dec_in_dict['joints_seq_latents_enc'] = clean_joint_seq_latents

            if self.args.kl_weights > 0. and "joint_seq_output_mean" in out_dict and not self.args.train_diff:
                # clean_joint_seq_latents: seq_len x bs x d #
                log_p_joints_seq = model_util.standard_normal_logprob(clean_joint_seq_latents)
                log_p_joints_seq = log_p_joints_seq.permute(1, 0, 2).contiguous()
                log_p_joints_seq = log_p_joints_seq.sum(dim=-1).mean(dim=-1)  # log_p_joints_seq
                entropy_joints_seq = model_util.gaussian_entropy(out_dict['joint_seq_output_logvar'].permute(1, 2, 0)).mean(dim=-1)
                loss_prior_joints_seq = (- log_p_joints_seq - entropy_joints_seq)
                KL_loss += loss_prior_joints_seq  # the dimension of latents? ##

        if self.args.diff_realbasejtsrel_to_joints:
            joints_offset_output = out_dict["joints_offset_output"]
            if self.args.pred_diff_noise:
                jts_pred_loss = torch.sum((joints_offset_output - noise_joints_offset_sequence) ** 2, dim=-1).mean(dim=-1).mean(dim=-1)  # bsz x ws x nnjts x 3
            else:
                jts_pred_loss = torch.sum((joints_offset_output - joints_offset_sequence) ** 2, dim=-1).mean(dim=-1).mean(dim=-1)
            terms['jts_pred_loss'] = jts_pred_loss
            terms['rot_mse'] += jts_pred_loss

        if self.diff_realbasejtsrel:
            real_dec_basejtsrel = out_dict['real_dec_basejtsrel']  # bsz x nf x nnj x nnb x 3 #
            # noise_rel_base_pts_to_rhand_joints, rel_base_pts_to_rhand_joints
            if self.args.pred_diff_noise and not self.args.train_enc:
                # print(f"here predicting diff_noise...")
                if self.args.use_jts_pert_realbasejtsrel:
                    # print(f"use_jts_pert_realbasejtsrel!!!")
                    jts_pred_loss = torch.sum((real_dec_basejtsrel[:, :, :, 0:1, :] - noise_rel_base_pts_to_rhand_joints[:, :, :, 0:1, :]) ** 2, dim=-1).mean(dim=-1).mean(dim=-1).mean(dim=-1)
                    # jts_pred_loss = torch.sum((real_dec_basejtsrel - noise_joints_offset_sequence) ** 2, dim=-1).mean(dim=-1).mean(dim=-1)
                else:
                    jts_pred_loss = torch.sum((real_dec_basejtsrel - noise_rel_base_pts_to_rhand_joints) ** 2, dim=-1).mean(dim=-1).mean(dim=-1).mean(dim=-1)
            else:
                jts_pred_loss = torch.sum((real_dec_basejtsrel - rel_base_pts_to_rhand_joints) ** 2, dim=-1).mean(dim=-1).mean(dim=-1).mean(dim=-1)
            terms['jts_pred_loss'] = jts_pred_loss
            terms['rot_mse'] += jts_pred_loss

            if self.args.train_enc:
                obj_base_pts_feats = out_dict['obj_base_pts_feats'].detach()
                noise_obj_base_pts_feats = th.randn_like(obj_base_pts_feats)
                pert_obj_base_pts_feats = self.q_sample(  ## pert in the latent space ###
                    obj_base_pts_feats.permute(1, 0, 2), t, noise_obj_base_pts_feats.permute(1, 0, 2)
                ).permute(1, 0, 2)  # seq_len x bsz x nn_feats_dim #
                dec_obj_base_pts_feats = model.model.denoising_realbasejtsrel_objbasefeats(pert_obj_base_pts_feats, self._scale_timesteps(t).clone())
                if self.args.pred_diff_noise:
                    obj_base_pts_feats_denoising_loss = torch.sum((dec_obj_base_pts_feats - noise_obj_base_pts_feats) ** 2, dim=-1) / noise_obj_base_pts_feats.size(-1)
                    obj_base_pts_feats_denoising_loss = obj_base_pts_feats_denoising_loss.transpose(0, 1).mean(dim=-1)
                else:
                    obj_base_pts_feats_denoising_loss = torch.sum((dec_obj_base_pts_feats - obj_base_pts_feats) ** 2, dim=-1) / obj_base_pts_feats.size(-1)
                    obj_base_pts_feats_denoising_loss = obj_base_pts_feats_denoising_loss.transpose(0, 1).mean(dim=-1)
                terms['jts_latent_denoising_loss'] = obj_base_pts_feats_denoising_loss
                terms['rot_mse'] += obj_base_pts_feats_denoising_loss

        if self.diff_basejtsrel:
            if 'basejtsrel_output' in out_dict:
                basejtsrel_output = out_dict['basejtsrel_output'].transpose(-2, -3).contiguous()
                avg_jts_outputs = out_dict['avg_jts_outputs']
                # print(f"basejtsrel_output: {basejtsrel_output.size()}, noise_rel_base_pts_to_rhand_joints: {noise_rel_base_pts_to_rhand_joints.size()}, rel_base_pts_to_rhand_joints: {rel_base_pts_to_rhand_joints.size()}")
                if self.args.pred_diff_noise:
                    basejtsrel_denoising_loss = torch.sum((basejtsrel_output - noise_rel_base_pts_to_rhand_joints) ** 2, dim=-1).mean(dim=-1).mean(dim=-1).mean(dim=-1)
                    avgjts_denoising_loss = torch.sum((avg_jts_outputs - noise_avg_joints_sequence) ** 2, dim=-1).mean(dim=-1)
                else:
                    basejtsrel_denoising_loss = torch.sum((basejtsrel_output - rel_base_pts_to_rhand_joints) ** 2, dim=-1).mean(dim=-1).mean(dim=-1).mean(dim=-1)
                    avgjts_denoising_loss = torch.sum((avg_jts_outputs - avg_joints_sequence) ** 2, dim=-1).mean(dim=-1)
            else:
                joints_offset_output = out_dict['joints_offset_output']
                if self.args.pred_diff_noise:
                    basejtsrel_denoising_loss = torch.sum((joints_offset_output - noise_joints_offset_sequence) ** 2, dim=-1).mean(dim=-1).mean(dim=-1)  # bsz x ws x nnjts x 3 --> mean and mean over dim=-1
                else:
                    # basejtsrel denoising losses ##
                    basejtsrel_denoising_loss = torch.sum((joints_offset_output - joints_offset_sequence) ** 2, dim=-1).mean(dim=-1).mean(dim=-1)
                if 'avg_jts_outputs' in out_dict:  # avg jts outputs ##
                    avg_jts_outputs = out_dict['avg_jts_outputs']
                    if self.args.pred_diff_noise:
                        avgjts_denoising_loss = torch.sum((avg_jts_outputs - noise_avg_joints_sequence) ** 2, dim=-1).mean(dim=-1)
                    else:
                        avgjts_denoising_loss = torch.sum((avg_jts_outputs - avg_joints_sequence) ** 2, dim=-1).mean(dim=-1)
                else:
                    avgjts_denoising_loss = torch.zeros_like(basejtsrel_denoising_loss)
            terms['basejtrel_denoising_loss'] = basejtsrel_denoising_loss
            terms['avgjts_denoising_loss'] = avgjts_denoising_loss
            terms['rot_mse'] += basejtsrel_denoising_loss + avgjts_denoising_loss  # jts denoising ##

        if self.diff_basejtse:
            ### Sample for perturbed basejtse seq latents ###
            dec_e_along_normals = out_dict['dec_e_along_normals']
            dec_e_vt_normals = out_dict['dec_e_vt_normals']
            # dec_e_along_normals; dec_e_vt_normals # bsz x nnj x nnb x 1 #
            # noise_e_disp_rel_to_base_along_normals, noise_e_disp_rel_to_base_vt_normals
            if self.args.pred_diff_noise:  ## predict e along and vt normals ##
                pred_e_along_normals_loss = ((dec_e_along_normals - noise_e_disp_rel_to_base_along_normals) ** 2).mean(dim=-1).mean(dim=-1).mean(dim=-1)
                pred_e_along_vt_loss = ((dec_e_vt_normals - noise_e_disp_rel_to_base_vt_normals) ** 2).mean(dim=-1).mean(dim=-1).mean(dim=-1)
            terms['basejtse_along_normals_pred_loss'] = pred_e_along_normals_loss
            terms['basejtse_vt_normals_pred_loss'] = pred_e_along_vt_loss
            terms['rot_mse'] += pred_e_along_normals_loss + pred_e_along_vt_loss
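        # NOTE: `masked_l2` (see its definition further down in this file)
        # computes a per-sample MSE over the unmasked entries only:
        #   sum((a - b)**2 * mask) / (sum(mask) * J * Jdim).
        # Commented usage sketch with hypothetical shapes:
        #   a, b = torch.randn(2, 24, 3, 60), torch.randn(2, 24, 3, 60)  # bs x J x Jdim x frames
        #   mask = torch.ones(2, 1, 1, 60)                               # bs x 1 x 1 x frames
        #   per_sample = self.masked_l2(a, b, mask)                      # -> shape [2]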
        # sv_out_in = {
        #     'model_output': model_output.detach().cpu().numpy(),
        #     'target': target.detach().cpu().numpy(),
        #     't': t.detach().cpu().numpy(),
        # }
        import os
        import datetime
        cur_time_stamp = datetime.datetime.now().timestamp()
        cur_time_stamp = str(cur_time_stamp)
        # sv_dir_rt = "/data1/sim/motion-diffusion-model"
        # sv_out_fn = os.path.join(sv_dir_rt, f"out_{cur_time_stamp}.npy")
        # np.save(sv_out_fn, sv_inter_dict)
        # print(f"Samples saved to {sv_out_fn}")

        target_xyz, model_output_xyz = None, None

        if self.lambda_rcxyz > 0. and dataset.dataname not in ['motion_ours']:
            print(f"Calculating lambda_rcxyz!!!")
            target_xyz = get_xyz(target)  # [bs, nvertices(vertices)/njoints(smpl), 3, nframes]
            model_output_xyz = get_xyz(model_output)  # [bs, nvertices, 3, nframes]
            terms["rcxyz_mse"] = self.masked_l2(target_xyz, model_output_xyz, mask)  # mean_flat((target_xyz - model_output_xyz) ** 2)

        if self.lambda_vel_rcxyz > 0.:
            if self.data_rep == 'rot6d' and dataset.dataname in ['humanact12', 'uestc']:
                target_xyz = get_xyz(target) if target_xyz is None else target_xyz
                model_output_xyz = get_xyz(model_output) if model_output_xyz is None else model_output_xyz
                target_xyz_vel = (target_xyz[:, :, :, 1:] - target_xyz[:, :, :, :-1])
                model_output_xyz_vel = (model_output_xyz[:, :, :, 1:] - model_output_xyz[:, :, :, :-1])
                terms["vel_xyz_mse"] = self.masked_l2(target_xyz_vel, model_output_xyz_vel, mask[:, :, :, 1:])

        if self.lambda_fc > 0.:  ## lambda fc ##
            torch.autograd.set_detect_anomaly(True)
            if self.data_rep == 'rot6d' and dataset.dataname in ['humanact12', 'uestc']:
                target_xyz = get_xyz(target) if target_xyz is None else target_xyz
                model_output_xyz = get_xyz(model_output) if model_output_xyz is None else model_output_xyz
                # 'L_Ankle': 7, 'R_Ankle': 8, 'L_Foot': 10, 'R_Foot': 11
                l_ankle_idx, r_ankle_idx, l_foot_idx, r_foot_idx = 7, 8, 10, 11
                relevant_joints = [l_ankle_idx, l_foot_idx, r_ankle_idx, r_foot_idx]
                gt_joint_xyz = target_xyz[:, relevant_joints, :, :]  # [BatchSize, 4, 3, Frames]
                gt_joint_vel = torch.linalg.norm(gt_joint_xyz[:, :, :, 1:] - gt_joint_xyz[:, :, :, :-1], axis=2)  # [BatchSize, 4, Frames]
                fc_mask = torch.unsqueeze((gt_joint_vel <= 0.01), dim=2).repeat(1, 1, 3, 1)
                pred_joint_xyz = model_output_xyz[:, relevant_joints, :, :]  # [BatchSize, 4, 3, Frames]
                pred_vel = pred_joint_xyz[:, :, :, 1:] - pred_joint_xyz[:, :, :, :-1]
                pred_vel[~fc_mask] = 0
                terms["fc"] = self.masked_l2(pred_vel, torch.zeros(pred_vel.shape, device=pred_vel.device), mask[:, :, :, 1:])

        self.lambda_vel = 0.  # NOTE: hard-coded override; this disables the velocity loss branch below.
        if self.lambda_vel > 0.:
            target_vel = (target[..., 1:] - target[..., :-1])
            model_output_vel = (model_output[..., 1:] - model_output[..., :-1])
            terms["vel_mse"] = self.masked_l2(target_vel[:, :-1, :, :],  # Remove last joint, is the root location!
                                              model_output_vel[:, :-1, :, :],
                                              mask[:, :, :, 1:])  # mean_flat((target_vel - model_output_vel) ** 2)

        terms["loss"] = terms["rot_mse"] + terms.get('vb', 0.) +\
                        (self.lambda_vel * terms.get('vel_mse', 0.)) +\
                        (self.lambda_rcxyz * terms.get('rcxyz_mse', 0.)) + \
                        (self.lambda_fc * terms.get('fc', 0.))
        # else:
        #     raise NotImplementedError(self.loss_type)
        return terms  ## training losses

    def predict_sample_single_step(self, model, x_start, t, model_kwargs=None, noise=None, dataset=None):
        """
        Compute training losses for a single timestep and return the model
        prediction together with its target.

        :param model: the model to evaluate loss on.
        :param x_start: the [N x C x ...] tensor of inputs.
        :param t: a batch of timestep indices.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param noise: if specified, the specific Gaussian noise to try to remove.
        :return: a dict with the key "loss" containing a tensor of shape [N].
                 Some mean or variance settings may also have other keys.
        """
        # enc = model.model._modules['module']
        # enc = model.model
        mask = model_kwargs['y']['mask']  ## rot2xyz ##

        ### avg_joints, std_joints ###
        if 'avg_joints' in model_kwargs['y']:
            avg_joints = model_kwargs['y']['avg_joints'].unsqueeze(-1)
            std_joints = model_kwargs['y']['std_joints'].unsqueeze(-1).unsqueeze(-1)
        else:
            avg_joints = None
            std_joints = None
        ### avg_joints, std_joints ###

        # get_xyz = lambda sample: enc.rot2xyz(sample, mask=None, pose_rep=enc.pose_rep, translation=enc.translation,
        #                                      glob=enc.glob,  ## rot2xyz ##
        #                                      # jointstype='vertices',  # 3.4 iter/sec # USED ALSO IN MotionCLIP
        #                                      jointstype='smpl',  # 3.4 iter/sec
        #                                      vertstrans=False)
        if model_kwargs is None:  ## predict single steps -->
            model_kwargs = {}
        if noise is None:
            noise = th.randn_like(x_start)
        x_t = self.q_sample(x_start, t, noise=noise)  # get x_t from x_start # how do we control the time stamp t?

        terms = {}
        if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
            terms["loss"] = self._vb_terms_bpd(  ## vb terms bpd #
                model=model,
                x_start=x_start,
                x_t=x_t,
                t=t,
                clip_denoised=False,
                model_kwargs=model_kwargs,
            )["output"]
            if self.loss_type == LossType.RESCALED_KL:
                terms["loss"] *= self.num_timesteps
        elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
            model_output = model(x_t, self._scale_timesteps(t), **model_kwargs)  # model_output ---> model x_t #

            if self.model_var_type in [
                ModelVarType.LEARNED,
                ModelVarType.LEARNED_RANGE,
            ]:
                B, C = x_t.shape[:2]
                assert model_output.shape == (B, C * 2, *x_t.shape[2:])
                model_output, model_var_values = th.split(model_output, C, dim=1)
                # Learn the variance using the variational bound, but don't let
                # it affect our mean prediction.
                frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
                terms["vb"] = self._vb_terms_bpd(
                    model=lambda *args, r=frozen_out: r,
                    x_start=x_start,
                    x_t=x_t,
                    t=t,
                    clip_denoised=False,
                )["output"]
                if self.loss_type == LossType.RESCALED_MSE:
                    # Divide by 1000 for equivalence with initial implementation.
                    # Without a factor of 1/1000, the VB term hurts the MSE term.
terms["vb"] *= self.num_timesteps / 1000.0 target = { # q posterior mean variance # ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( x_start=x_start, x_t=x_t, t=t )[0], ModelMeanType.START_X: x_start, ModelMeanType.EPSILON: noise, }[self.model_mean_type] # model mean type --> mean type # assert model_output.shape == target.shape == x_start.shape # [bs, njoints, nfeats, nframes] terms["rot_mse"] = self.masked_l2(target, model_output, mask) # mean_flat(rot_mse) # model_output, target, t # ### avg_joints, std_joints ### # if avg_joints is not None: print(f"model_output: {model_output.size()}, target: {target.size()}, std_joints: {std_joints.size()}, avg_joints: {avg_joints.size()}") print(f"Denormalizing joints...") model_output = (model_output * std_joints) + avg_joints target = (target * std_joints) + avg_joints sv_out_in = { # 'model_output': model_output.detach().cpu().numpy(), 'target': target.detach().cpu().numpy(), 't': t.detach().cpu().numpy(), } import os import datetime cur_time_stamp = datetime.datetime.now().timestamp() cur_time_stamp = str(cur_time_stamp) sv_dir_rt = "/data1/sim/motion-diffusion-model" sv_out_fn = os.path.join(sv_dir_rt, f"out_{cur_time_stamp}.npy") np.save(sv_out_fn,sv_out_in ) print(f"Samples saved to {sv_out_fn}") target_xyz, model_output_xyz = None, None self.lambda_rcxyz = 0. if self.lambda_rcxyz > 0.: target_xyz = get_xyz(target) # [bs, nvertices(vertices)/njoints(smpl), 3, nframes] model_output_xyz = get_xyz(model_output) # [bs, nvertices, 3, nframes] terms["rcxyz_mse"] = self.masked_l2(target_xyz, model_output_xyz, mask) # mean_flat((target_xyz - model_output_xyz) ** 2) if self.lambda_vel_rcxyz > 0.: if self.data_rep == 'rot6d' and dataset.dataname in ['humanact12', 'uestc']: target_xyz = get_xyz(target) if target_xyz is None else target_xyz model_output_xyz = get_xyz(model_output) if model_output_xyz is None else model_output_xyz target_xyz_vel = (target_xyz[:, :, :, 1:] - target_xyz[:, :, :, :-1]) model_output_xyz_vel = (model_output_xyz[:, :, :, 1:] - model_output_xyz[:, :, :, :-1]) terms["vel_xyz_mse"] = self.masked_l2(target_xyz_vel, model_output_xyz_vel, mask[:, :, :, 1:]) if self.lambda_fc > 0.: ## lambda fc ## torch.autograd.set_detect_anomaly(True) if self.data_rep == 'rot6d' and dataset.dataname in ['humanact12', 'uestc']: target_xyz = get_xyz(target) if target_xyz is None else target_xyz model_output_xyz = get_xyz(model_output) if model_output_xyz is None else model_output_xyz # 'L_Ankle', # 7, 'R_Ankle', # 8 , 'L_Foot', # 10, 'R_Foot', # 11 l_ankle_idx, r_ankle_idx, l_foot_idx, r_foot_idx = 7, 8, 10, 11 relevant_joints = [l_ankle_idx, l_foot_idx, r_ankle_idx, r_foot_idx] gt_joint_xyz = target_xyz[:, relevant_joints, :, :] # [BatchSize, 4, 3, Frames] gt_joint_vel = torch.linalg.norm(gt_joint_xyz[:, :, :, 1:] - gt_joint_xyz[:, :, :, :-1], axis=2) # [BatchSize, 4, Frames] fc_mask = torch.unsqueeze((gt_joint_vel <= 0.01), dim=2).repeat(1, 1, 3, 1) pred_joint_xyz = model_output_xyz[:, relevant_joints, :, :] # [BatchSize, 4, 3, Frames] pred_vel = pred_joint_xyz[:, :, :, 1:] - pred_joint_xyz[:, :, :, :-1] pred_vel[~fc_mask] = 0 terms["fc"] = self.masked_l2(pred_vel, torch.zeros(pred_vel.shape, device=pred_vel.device), mask[:, :, :, 1:]) if self.lambda_vel > 0.: target_vel = (target[..., 1:] - target[..., :-1]) model_output_vel = (model_output[..., 1:] - model_output[..., :-1]) terms["vel_mse"] = self.masked_l2(target_vel[:, :-1, :, :], # Remove last joint, is the root location! 
model_output_vel[:, :-1, :, :], mask[:, :, :, 1:]) # mean_flat((target_vel - model_output_vel) ** 2) terms["loss"] = terms["rot_mse"] + terms.get('vb', 0.) +\ (self.lambda_vel * terms.get('vel_mse', 0.)) +\ (self.lambda_rcxyz * terms.get('rcxyz_mse', 0.)) + \ (self.lambda_fc * terms.get('fc', 0.)) else: raise NotImplementedError(self.loss_type) return terms, model_output, target, t def fc_loss_rot_repr(self, gt_xyz, pred_xyz, mask): def to_np_cpu(x): return x.detach().cpu().numpy() """ pose_xyz: SMPL batch tensor of shape: [BatchSize, 24, 3, Frames] """ # 'L_Ankle', # 7, 'R_Ankle', # 8 , 'L_Foot', # 10, 'R_Foot', # 11 l_ankle_idx, r_ankle_idx = 7, 8 l_foot_idx, r_foot_idx = 10, 11 """ Contact calculated by 'Kfir Method' Commented code)""" # contact_signal = torch.zeros((pose_xyz.shape[0], pose_xyz.shape[3], 2), device=pose_xyz.device) # [BatchSize, Frames, 2] # left_xyz = 0.5 * (pose_xyz[:, l_ankle_idx, :, :] + pose_xyz[:, l_foot_idx, :, :]) # [BatchSize, 3, Frames] # right_xyz = 0.5 * (pose_xyz[:, r_ankle_idx, :, :] + pose_xyz[:, r_foot_idx, :, :]) # left_z, right_z = left_xyz[:, 2, :], right_xyz[:, 2, :] # [BatchSize, Frames] # left_velocity = torch.linalg.norm(left_xyz[:, :, 2:] - left_xyz[:, :, :-2], axis=1) # [BatchSize, Frames] # right_velocity = torch.linalg.norm(left_xyz[:, :, 2:] - left_xyz[:, :, :-2], axis=1) # # left_z_mask = left_z <= torch.mean(torch.sort(left_z)[0][:, :left_z.shape[1] // 5], axis=-1) # left_z_mask = torch.stack([left_z_mask, left_z_mask], dim=-1) # [BatchSize, Frames, 2] # left_z_mask[:, :, 1] = False # Blank right side # contact_signal[left_z_mask] = 0.4 # # right_z_mask = right_z <= torch.mean(torch.sort(right_z)[0][:, :right_z.shape[1] // 5], axis=-1) # right_z_mask = torch.stack([right_z_mask, right_z_mask], dim=-1) # [BatchSize, Frames, 2] # right_z_mask[:, :, 0] = False # Blank left side # contact_signal[right_z_mask] = 0.4 # contact_signal[left_z <= (torch.mean(torch.sort(left_z)[:left_z.shape[0] // 5]) + 20), 0] = 1 # contact_signal[right_z <= (torch.mean(torch.sort(right_z)[:right_z.shape[0] // 5]) + 20), 1] = 1 # plt.plot(to_np_cpu(left_z[0]), label='left_z') # plt.plot(to_np_cpu(left_velocity[0]), label='left_velocity') # plt.plot(to_np_cpu(contact_signal[0, :, 0]), label='left_fc') # plt.grid() # plt.legend() # plt.show() # plt.plot(to_np_cpu(right_z[0]), label='right_z') # plt.plot(to_np_cpu(right_velocity[0]), label='right_velocity') # plt.plot(to_np_cpu(contact_signal[0, :, 1]), label='right_fc') # plt.grid() # plt.legend() # plt.show() gt_joint_xyz = gt_xyz[:, [l_ankle_idx, l_foot_idx, r_ankle_idx, r_foot_idx], :, :] # [BatchSize, 4, 3, Frames] gt_joint_vel = torch.linalg.norm(gt_joint_xyz[:, :, :, 1:] - gt_joint_xyz[:, :, :, :-1], axis=2) # [BatchSize, 4, Frames] fc_mask = (gt_joint_vel <= 0.01) pred_joint_xyz = pred_xyz[:, [l_ankle_idx, l_foot_idx, r_ankle_idx, r_foot_idx], :, :] # [BatchSize, 4, 3, Frames] pred_joint_vel = torch.linalg.norm(pred_joint_xyz[:, :, :, 1:] - pred_joint_xyz[:, :, :, :-1], axis=2) # [BatchSize, 4, Frames] pred_joint_vel[~fc_mask] = 0 # Blank non-contact velocities frames. 
[BS,4,FRAMES] pred_joint_vel = torch.unsqueeze(pred_joint_vel, dim=2) """DEBUG CODE""" # print(f'mask: {mask.shape}') # print(f'pred_joint_vel: {pred_joint_vel.shape}') # plt.title(f'Joint: {joint_idx}') # plt.plot(to_np_cpu(gt_joint_vel[0]), label='velocity') # plt.plot(to_np_cpu(fc_mask[0]), label='fc') # plt.grid() # plt.legend() # plt.show() return self.masked_l2(pred_joint_vel, torch.zeros(pred_joint_vel.shape, device=pred_joint_vel.device), mask[:, :, :, 1:]) # TODO - NOT USED YET, JUST COMMITING TO NOT DELETE THIS AND KEEP INITIAL IMPLEMENTATION, NOT DONE! def foot_contact_loss_humanml3d(self, target, model_output): # root_rot_velocity (B, seq_len, 1) # root_linear_velocity (B, seq_len, 2) # root_y (B, seq_len, 1) # ric_data (B, seq_len, (joint_num - 1)*3) , XYZ # rot_data (B, seq_len, (joint_num - 1)*6) , 6D # local_velocity (B, seq_len, joint_num*3) , XYZ # foot contact (B, seq_len, 4) , target_fc = target[:, -4:, :, :] root_rot_velocity = target[:, :1, :, :] root_linear_velocity = target[:, 1:3, :, :] root_y = target[:, 3:4, :, :] ric_data = target[:, 4:67, :, :] # 4+(3*21)=67 rot_data = target[:, 67:193, :, :] # 67+(6*21)=193 local_velocity = target[:, 193:259, :, :] # 193+(3*22)=259 contact = target[:, 259:, :, :] # 193+(3*22)=259 contact_mask_gt = contact > 0.5 # contact mask order for indexes are fid_l [7, 10], fid_r [8, 11] vel_lf_7 = local_velocity[:, 7 * 3:8 * 3, :, :] vel_rf_8 = local_velocity[:, 8 * 3:9 * 3, :, :] vel_lf_10 = local_velocity[:, 10 * 3:11 * 3, :, :] vel_rf_11 = local_velocity[:, 11 * 3:12 * 3, :, :] calc_vel_lf_7 = ric_data[:, 6 * 3:7 * 3, :, 1:] - ric_data[:, 6 * 3:7 * 3, :, :-1] calc_vel_rf_8 = ric_data[:, 7 * 3:8 * 3, :, 1:] - ric_data[:, 7 * 3:8 * 3, :, :-1] calc_vel_lf_10 = ric_data[:, 9 * 3:10 * 3, :, 1:] - ric_data[:, 9 * 3:10 * 3, :, :-1] calc_vel_rf_11 = ric_data[:, 10 * 3:11 * 3, :, 1:] - ric_data[:, 10 * 3:11 * 3, :, :-1] # vel_foots = torch.stack([vel_lf_7, vel_lf_10, vel_rf_8, vel_rf_11], dim=1) for chosen_vel_foot_calc, chosen_vel_foot, joint_idx, contact_mask_idx in zip( [calc_vel_lf_7, calc_vel_rf_8, calc_vel_lf_10, calc_vel_rf_11], [vel_lf_7, vel_lf_10, vel_rf_8, vel_rf_11], [7, 10, 8, 11], [0, 1, 2, 3]): tmp_mask_gt = contact_mask_gt[:, contact_mask_idx, :, :].cpu().detach().numpy().reshape(-1).astype(int) chosen_vel_norm = np.linalg.norm(chosen_vel_foot.cpu().detach().numpy().reshape((3, -1)), axis=0) chosen_vel_calc_norm = np.linalg.norm(chosen_vel_foot_calc.cpu().detach().numpy().reshape((3, -1)), axis=0) print(tmp_mask_gt.shape) print(chosen_vel_foot.shape) print(chosen_vel_calc_norm.shape) import matplotlib.pyplot as plt plt.plot(tmp_mask_gt, label='FC mask') plt.plot(chosen_vel_norm, label='Vel. XYZ norm (from vector)') plt.plot(chosen_vel_calc_norm, label='Vel. XYZ norm (calculated diff XYZ)') plt.title(f'FC idx {contact_mask_idx}, Joint Index {joint_idx}') plt.legend() plt.show() # print(vel_foots.shape) return 0 # TODO - NOT USED YET, JUST COMMITING TO NOT DELETE THIS AND KEEP INITIAL IMPLEMENTATION, NOT DONE! 
def velocity_consistency_loss_humanml3d(self, target, model_output): # root_rot_velocity (B, seq_len, 1) # root_linear_velocity (B, seq_len, 2) # root_y (B, seq_len, 1) # ric_data (B, seq_len, (joint_num - 1)*3) , XYZ # rot_data (B, seq_len, (joint_num - 1)*6) , 6D # local_velocity (B, seq_len, joint_num*3) , XYZ # foot contact (B, seq_len, 4) , target_fc = target[:, -4:, :, :] root_rot_velocity = target[:, :1, :, :] root_linear_velocity = target[:, 1:3, :, :] root_y = target[:, 3:4, :, :] ric_data = target[:, 4:67, :, :] # 4+(3*21)=67 rot_data = target[:, 67:193, :, :] # 67+(6*21)=193 local_velocity = target[:, 193:259, :, :] # 193+(3*22)=259 contact = target[:, 259:, :, :] # 193+(3*22)=259 calc_vel_from_xyz = ric_data[:, :, :, 1:] - ric_data[:, :, :, :-1] velocity_from_vector = local_velocity[:, 3:, :, 1:] # Slicing out root r_rot_quat, r_pos = motion_process.recover_root_rot_pos(target.permute(0, 2, 3, 1).type(th.FloatTensor)) print(f'r_rot_quat: {r_rot_quat.shape}') print(f'calc_vel_from_xyz: {calc_vel_from_xyz.shape}') calc_vel_from_xyz = calc_vel_from_xyz.permute(0, 2, 3, 1) calc_vel_from_xyz = calc_vel_from_xyz.reshape((1, 1, -1, 21, 3)).type(th.FloatTensor) r_rot_quat_adapted = r_rot_quat[..., :-1, None, :].repeat((1,1,1,21,1)).to(calc_vel_from_xyz.device) print(f'calc_vel_from_xyz: {calc_vel_from_xyz.shape} , {calc_vel_from_xyz.device}') print(f'r_rot_quat_adapted: {r_rot_quat_adapted.shape}, {r_rot_quat_adapted.device}') calc_vel_from_xyz = motion_process.qrot(r_rot_quat_adapted, calc_vel_from_xyz) calc_vel_from_xyz = calc_vel_from_xyz.reshape((1, 1, -1, 21 * 3)) calc_vel_from_xyz = calc_vel_from_xyz.permute(0, 3, 1, 2) print(f'calc_vel_from_xyz: {calc_vel_from_xyz.shape} , {calc_vel_from_xyz.device}') import matplotlib.pyplot as plt for i in range(21): plt.plot(np.linalg.norm(calc_vel_from_xyz[:,i*3:(i+1)*3,:,:].cpu().detach().numpy().reshape((3, -1)), axis=0), label='Calc Vel') plt.plot(np.linalg.norm(velocity_from_vector[:,i*3:(i+1)*3,:,:].cpu().detach().numpy().reshape((3, -1)), axis=0), label='Vector Vel') plt.title(f'Joint idx: {i}') plt.legend() plt.show() print(calc_vel_from_xyz.shape) print(velocity_from_vector.shape) diff = calc_vel_from_xyz-velocity_from_vector print(np.linalg.norm(diff.cpu().detach().numpy().reshape((63, -1)), axis=0)) return 0 def _prior_bpd(self, x_start): """ Get the prior KL term for the variational lower-bound, measured in bits-per-dim. This term can't be optimized, as it only depends on the encoder. :param x_start: the [N x C x ...] tensor of inputs. :return: a batch of [N] KL values (in bits), one per batch element. """ batch_size = x_start.shape[0] t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) kl_prior = normal_kl( mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 ) return mean_flat(kl_prior) / np.log(2.0) def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): """ Compute the entire variational lower-bound, measured in bits-per-dim, as well as other related quantities. :param model: the model to evaluate loss on. :param x_start: the [N x C x ...] tensor of inputs. :param clip_denoised: if True, clip denoised samples. :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :return: a dict containing the following keys: - total_bpd: the total variational lower-bound, per batch element. - prior_bpd: the prior term in the lower-bound. 
- vb: an [N x T] tensor of terms in the lower-bound. - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. - mse: an [N x T] tensor of epsilon MSEs for each timestep. """ device = x_start.device batch_size = x_start.shape[0] vb = [] xstart_mse = [] mse = [] for t in list(range(self.num_timesteps))[::-1]: t_batch = th.tensor([t] * batch_size, device=device) noise = th.randn_like(x_start) x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) # Calculate VLB term at the current timestep with th.no_grad(): out = self._vb_terms_bpd( model, x_start=x_start, x_t=x_t, t=t_batch, clip_denoised=clip_denoised, model_kwargs=model_kwargs, ) vb.append(out["output"]) xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) mse.append(mean_flat((eps - noise) ** 2)) vb = th.stack(vb, dim=1) xstart_mse = th.stack(xstart_mse, dim=1) mse = th.stack(mse, dim=1) prior_bpd = self._prior_bpd(x_start) total_bpd = vb.sum(dim=1) + prior_bpd return { "total_bpd": total_bpd, "prior_bpd": prior_bpd, "vb": vb, "xstart_mse": xstart_mse, "mse": mse, } # and we show results on the challenging cases # class GaussianDiffusionV9: """ Utilities for training and sampling diffusion models. Ported directly from here, and then adapted over time to further experimentation. https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 :param betas: a 1-D numpy array of betas for each diffusion timestep, ## pass floating point timesteps ## starting at T and going to 1. :param model_mean_type: a ModelMeanType determining what the model outputs. :param model_var_type: a ModelVarType determining how variance is output. :param loss_type: a LossType determining the loss function to use. :param rescale_timesteps: if True, pass floating point timesteps into the model so that they are always scaled like in the original paper (0 to 1000). """ def __init__( self, *, betas, model_mean_type, model_var_type, loss_type, rescale_timesteps=False, lambda_rcxyz=0., lambda_vel=0., lambda_pose=1., lambda_orient=1., lambda_loc=1., data_rep='rot6d', lambda_root_vel=0., lambda_vel_rcxyz=0., lambda_fc=0., denoising_stra="rep", inter_optim=False, args=None, ): self.model_mean_type = model_mean_type self.model_var_type = model_var_type ## model var type ## self.loss_type = loss_type self.rescale_timesteps = rescale_timesteps self.data_rep = data_rep self.args = args # possibly None ### GET the diff. suit ### self.diff_jts = self.args.diff_jts self.diff_basejtsrel = self.args.diff_basejtsrel self.diff_basejtse = self.args.diff_basejtse self.diff_realbasejtsrel = self.args.diff_realbasejtsrel ### GET the diff. suit ### if data_rep != 'rot_vel' and lambda_pose != 1.: raise ValueError('lambda_pose is relevant only when training on velocities!') self.lambda_pose = lambda_pose self.lambda_orient = lambda_orient self.lambda_loc = lambda_loc self.lambda_rcxyz = lambda_rcxyz self.lambda_vel = lambda_vel self.lambda_root_vel = lambda_root_vel self.lambda_vel_rcxyz = lambda_vel_rcxyz self.lambda_fc = lambda_fc ### === denoising_stra for the denoising process === ### self.denoising_stra = denoising_stra self.inter_optim = inter_optim if self.lambda_rcxyz > 0. or self.lambda_vel > 0. or self.lambda_root_vel > 0. or \ self.lambda_vel_rcxyz > 0. or self.lambda_fc > 0.: assert self.loss_type == LossType.MSE, 'Geometric losses are supported by MSE loss type only!' 
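        # NOTE: `VarianceSchedule` below is the helper consumed by the sampling
        # code further down (see `self.var_sched.alphas`, `.alpha_bars`,
        # `.betas`, and `.get_sigmas(t, flexibility)` in p_mean_variance_cond).
        # A commented usage sketch, assuming that interface; `x_t`, `eps_pred`,
        # and `z` are hypothetical tensors for one reverse step:
        #   alpha = self.var_sched.alphas[t_item]
        #   alpha_bar = self.var_sched.alpha_bars[t_item]
        #   sigma = self.var_sched.get_sigmas(t_item, 0.)
        #   c0 = 1.0 / torch.sqrt(alpha)
        #   c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
        #   x_prev = c0 * (x_t - c1 * eps_pred) + sigma * z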
self.var_sched = VarianceSchedule(len(betas), torch.tensor(betas, dtype=torch.float64)) # Use float64 for accuracy. betas = np.array(betas, dtype=np.float64) self.betas = betas assert len(betas.shape) == 1, "betas must be 1-D" ## betas assert (betas > 0).all() and (betas <= 1).all() self.num_timesteps = int(betas.shape[0]) alphas = 1.0 - betas self.alphas_cumprod = np.cumprod(alphas, axis=0) self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) ## alphas cumprod self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) # calculations for diffusion q(x_t | x_{t-1}) and others self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) # calculations for posterior q(x_{t-1} | x_t, x_0) self.posterior_variance = ( betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) ) # log calculation clipped because the posterior variance is 0 at the # beginning of the diffusion chain. self.posterior_log_variance_clipped = np.log( np.append(self.posterior_variance[1], self.posterior_variance[1:]) ) self.posterior_mean_coef1 = ( betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) ) self.posterior_mean_coef2 = ( # posterior mean coefs # (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) ) self.l2_loss = lambda a, b: (a - b) ** 2 # th.nn.MSELoss(reduction='none') # must be None for handling mask later on. # canon_verts1 self.sel_faces_values = None self.canon_verts1 = None self.canon_sel_faces_values = None self.mano_path = "/data1/sim/mano_models/mano/models" # self.mano_layer = ManoLayer( flat_hand_mean=True, side='right', mano_root=self.mano_path, # mano_root # ncomps=24, use_pca=True, root_rot_mode='axisang', joint_rot_mode='axisang' ).cuda() ''' Load statistics ''' # avg_joints_motion_ours_fn = f"/home/xueyi/sim/motion-diffusion-model/avg_joints_motion_ours_nb_{700}_nth_{0.005}.npy" # std_joints_motion_ours_fn = f"/home/xueyi/sim/motion-diffusion-model/std_joints_motion_ours_nb_{700}_nth_{0.005}.npy" # avg_joints_motion_dists_ours_fn = f"/home/xueyi/sim/motion-diffusion-model/avg_joints_dist_motion_ours_nb_{700}_nth_{0.005}.npy" # std_joints_motion_dists_ours_fn = f"/home/xueyi/sim/motion-diffusion-model/std_joints_dist_motion_ours_nb_{700}_nth_{0.005}.npy" # avg_joints_rel = np.load(avg_joints_motion_ours_fn, allow_pickle=True) # std_joints_rel = np.load(std_joints_motion_ours_fn, allow_pickle=True) # avg_joints_dists = np.load(avg_joints_motion_dists_ours_fn, allow_pickle=True) # std_joints_dists = np.load(std_joints_motion_dists_ours_fn, allow_pickle=True) # ## self.avg_joints_rel, self.std_joints_rel # ## self.avg_joints_dists, self.std_joints_dists # self.avg_joints_rel = torch.from_numpy(avg_joints_rel).float() # self.std_joints_rel = torch.from_numpy(std_joints_rel).float() # self.avg_joints_dists = torch.from_numpy(avg_joints_dists).float() # self.std_joints_dists = torch.from_numpy(std_joints_dists).float() ''' Load statistics ''' ''' Load avg, std statistics ''' # # self.maxx_rel, minn_rel, maxx_dists, minn_dists # # rel_dists_stats_fn = "/home/xueyi/sim/motion-diffusion-model/base_pts_rel_dists_stats.npy" # rel_dists_stats = np.load(rel_dists_stats_fn, 
allow_pickle=True).item() # maxx_rel = rel_dists_stats['maxx_rel'] # minn_rel = rel_dists_stats['minn_rel'] # maxx_dists = rel_dists_stats['maxx_dists'] # minn_dists = rel_dists_stats['minn_dists'] # self.maxx_rel = torch.from_numpy(maxx_rel).float() # self.minn_rel = torch.from_numpy(minn_rel).float() # self.maxx_dists = torch.from_numpy(maxx_dists).float() # self.minn_dists = torch.from_numpy(minn_dists).float() ''' Load avg, std statistics ''' ''' Load avg-jts, std-jts ''' # avg_jts_fn = "/home/xueyi/sim/motion-diffusion-model/avg_joints_motion_ours.npy" # std_jts_fn = "/home/xueyi/sim/motion-diffusion-model/std_joints_motion_ours.npy" # avg_jts = np.load(avg_jts_fn, allow_pickle=True) # std_jts = np.load(std_jts_fn, allow_pickle=True) # # self.avg_jts, self.std_jts # # self.avg_jts = torch.from_numpy(avg_jts).float() # self.std_jts = torch.from_numpy(std_jts).float() ''' Load avg-jts, std-jts ''' def masked_l2(self, a, b, mask): # assuming a.shape == b.shape == bs, J, Jdim, seqlen # assuming mask.shape == bs, 1, 1, seqlen loss = self.l2_loss(a, b) loss = sum_flat(loss * mask.float()) # gives \sigma_euclidean over unmasked elements n_entries = a.shape[1] * a.shape[2] non_zero_elements = sum_flat(mask) * n_entries ## non-zero-elements # print('mask', mask.shape) # print('non_zero_elements', non_zero_elements) # print('loss', loss) mse_loss_val = loss / non_zero_elements # print('mse_loss_val', mse_loss_val) return mse_loss_val def q_mean_variance(self, x_start, t): """ Get the distribution q(x_t | x_0). # q-mean-variance # :param x_start: the [N x C x ...] tensor of noiseless inputs. :param t: the number of diffusion steps (minus 1). Here, 0 means one step. :return: A tuple (mean, variance, log_variance), all of x_start's shape. """ mean = ( _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start ) variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) log_variance = _extract_into_tensor( self.log_one_minus_alphas_cumprod, t, x_start.shape ) return mean, variance, log_variance def q_sample(self, x_start, t, noise=None): """ Diffuse the dataset for a given number of diffusion steps. In other words, sample from q(x_t | x_0). ## q pos :param x_start: the initial dataset batch. :param t: the number of diffusion steps (minus 1). Here, 0 means one step. :param noise: if specified, the split-out normal noise. :return: A noisy version of x_start. 
""" if noise is None: noise = th.randn_like(x_start) assert noise.shape == x_start.shape return ( _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise ### variance xxx noise ### ) ## q_sample, def q_posterior_mean_variance(self, x_start, x_t, t): ## q_posterior """ Compute the mean and variance of the diffusion posterior: q(x_{t-1} | x_t, x_0) """ assert x_start.shape == x_t.shape posterior_mean = ( # posterior mean and variance # _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t ) posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) posterior_log_variance_clipped = _extract_into_tensor( self.posterior_log_variance_clipped, t, x_t.shape ) assert ( posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == x_start.shape[0] ) return posterior_mean, posterior_variance, posterior_log_variance_clipped def phy_projct_pred_joints(self, pred_joints, base_pts, base_normals): # pred_joints: bsz x nf x nn_jts x 3 # # base_pts: bsz x nn_base_pts x 3 # # base_normals: bsz x nn_base_pts x 3 # nf = pred_joints.size(1) if not self.args.use_arti_obj: base_pts_exp = base_pts.unsqueeze(1).repeat(1, nf, 1, 1).contiguous() base_normals_exp = base_normals.unsqueeze(1).repeat(1, nf, 1, 1).contiguous() else: base_pts_exp = base_pts.clone() base_normals_exp = base_normals.clone() nearest_pred_joints_to_base_pts = torch.sum( (pred_joints.unsqueeze(-2) - base_pts_exp.unsqueeze(2)) ** 2, dim=-1 ) nearest_dist, nearest_base_pts_idxes = torch.min(nearest_pred_joints_to_base_pts, dim=-1) # bsz x nf x nn_jts nearest_dist = torch.sqrt(nearest_dist) nearest_base_pts = model_util.batched_index_select_ours(base_pts_exp, nearest_base_pts_idxes, dim=2) # bsz x nf x nn_jts x 3 nearest_base_normals = model_util.batched_index_select_ours(base_normals_exp, nearest_base_pts_idxes, dim=2) # bsz x nf x nn_jts x 3 # jts_to_base_pts = pred_joints - nearest_base_pts # from base pts to pred joints # dot_rel_with_normals = torch.sum( jts_to_base_pts * nearest_base_normals, dim=-1 # bsz x nf x nn_jts --> joints inside of the object # ) jts_proj_dir = torch.zeros_like(nearest_base_pts) # bsz x nf x nn_jts x 3 # jts_proj_dir[dot_rel_with_normals < 0.] = jts_to_base_pts[dot_rel_with_normals < 0.] 
# bsz x nf x nn_jts x 3 # return jts_proj_dir # bsz x nf x nn_jts x 3 # returned gradients # def select_nearest_base_pts_via_normals(self, nearest_base_pts, nearest_base_normals): # nearest_base_pts: bsz x nf x nn_jts x 3 # nearest_base_normals: bsz x nf x nn_jts x 3 # bsz, nf, nn_jts = nearest_base_pts.size()[:3] # nearest_base_pts # bsz x nf x nn_jts x 3 # new_nearest_base_pts = [nearest_base_pts[:, 0:1, :]] # new_nearest_base_normals = [nearest_base_normals[:, 0:1, :]] for i_f in range(1, nf): cur_nearest_base_pts = nearest_base_pts[:, i_f] # bsz x nn_jts x 3 # cur_nearest_base_normals = nearest_base_normals[:, i_f] # bsz x nn_jts x 3 # prev_nearest_base_pts = nearest_base_pts[:, i_f - 1] # bsz x nn_jts x 3 # prev_nearest_base_normals = nearest_base_normals[:, i_f - 1] # bsz x nn_jts x 3 # dot_cur_n_with_prev_n = torch.sum( # cur nearest base normals; nearest base normals # cur_nearest_base_normals * prev_nearest_base_normals, dim=-1 # bsz x nn_jts for the nearest base normals and prev normals # ) cur_new_nearest_pts = cur_nearest_base_pts.clone() # bsz x nn_jts x 3 # for the base normals # cur_new_nearest_normals = cur_nearest_base_normals.clone() # bsz x nn_jts x 3 # for base normals # # if less than dot_cur_n_with_prev_n # then use thres = -0.3 cur_new_nearest_pts[dot_cur_n_with_prev_n < thres] = prev_nearest_base_pts[dot_cur_n_with_prev_n < thres] # the dot prev cur_new_nearest_normals[dot_cur_n_with_prev_n < thres] = prev_nearest_base_normals[dot_cur_n_with_prev_n < thres] # new nearest base pts # # # # cur new nearest pts, normals # new_nearest_base_pts.append(cur_new_nearest_pts.unsqueeze(1)) new_nearest_base_normals.append(cur_new_nearest_normals.unsqueeze(1)) # new_nearest_base_pts = torch.cat(new_nearest_base_pts, dim=1) # bsz x nf x nn_jts x 3 --> for nearest base_pts and normals # new_nearest_base_normals = torch.cat(new_nearest_base_normals, dim=1) # bsz x nf x nn_jts x 3 --> for nearest base pts and normals # return new_nearest_base_pts, new_nearest_base_normals def select_nearest_base_pts_via_normals_fr_mid(self, nearest_base_pts, nearest_base_normals): # nearest_base_pts: bsz x nf x nn_jts x 3 # nearest_base_normals: bsz x nf x nn_jts x 3 # bsz, nf, nn_jts = nearest_base_pts.size()[:3] # nearest_base_pts # bsz x nf x nn_jts x 3 # key_frame_idx = 50 key_frame_idx = 30 new_nearest_base_pts = [nearest_base_pts[:, key_frame_idx: key_frame_idx + 1, :]] # new_nearest_base_normals = [nearest_base_normals[:, key_frame_idx: key_frame_idx + 1, :]] # new_nearest_base_pts = [nearest_base_pts[:, 0:1, :]] # # new_nearest_base_normals = [nearest_base_normals[:, 0:1, :]] for i_f in range(key_frame_idx - 1, -1, -1): cur_nearest_base_pts = nearest_base_pts[:, i_f] # bsz x nn_jts x 3 # cur_nearest_base_normals = nearest_base_normals[:, i_f] # bsz x nn_jts x 3 # prev_nearest_base_pts = new_nearest_base_pts[0].squeeze(1) # bsz x nn_jts x 3 # prev_nearest_base_normals = new_nearest_base_normals[0].squeeze(1) # bsz x nn_jts x 3 dot_cur_n_with_prev_n = torch.sum( # cur nearest base normals; nearest base normals # cur_nearest_base_normals * prev_nearest_base_normals, dim=-1 # bsz x nn_jts for the nearest base normals and prev normals # ) cur_new_nearest_pts = cur_nearest_base_pts.clone() # bsz x nn_jts x 3 # for the base normals # cur_new_nearest_normals = cur_nearest_base_normals.clone() # bsz x nn_jts x 3 # for base normals # # if less than dot_cur_n_with_prev_n # then use thres = -0.0 thres = 0.7 cur_new_nearest_pts[dot_cur_n_with_prev_n < thres] = prev_nearest_base_pts[dot_cur_n_with_prev_n 
< thres] # the dot prev cur_new_nearest_normals[dot_cur_n_with_prev_n < thres] = prev_nearest_base_normals[dot_cur_n_with_prev_n < thres] # new nearest base pts # # new_nearest_base_pts = [cur_new_nearest_pts.unsqueeze(1)] + new_nearest_base_pts new_nearest_base_normals = [cur_new_nearest_normals.unsqueeze(1)] + new_nearest_base_normals for i_f in range(key_frame_idx + 1, nf, 1): cur_nearest_base_pts = nearest_base_pts[:, i_f] # bsz x nn_jts x 3 # cur_nearest_base_normals = nearest_base_normals[:, i_f] # bsz x nn_jts x 3 # prev_nearest_base_pts = new_nearest_base_pts[-1].squeeze(1) # bsz x nn_jts x 3 # prev_nearest_base_normals = new_nearest_base_normals[-1].squeeze(1) # bsz x nn_jts x 3 ## bsz x nn_jts x 3 ## # dot_cur_n_with_prev_n = torch.sum( # cur nearest base normals; nearest base normals # cur_nearest_base_normals * prev_nearest_base_normals, dim=-1 # bsz x nn_jts for the nearest base normals and prev normals # ) cur_new_nearest_pts = cur_nearest_base_pts.clone() # bsz x nn_jts x 3 # for the base normals # cur_new_nearest_normals = cur_nearest_base_normals.clone() # bsz x nn_jts x 3 # for base normals # # if less than dot_cur_n_with_prev_n # then use thres = -0.0 thres = 0.7 cur_new_nearest_pts[dot_cur_n_with_prev_n < thres] = prev_nearest_base_pts[dot_cur_n_with_prev_n < thres] # the dot prev cur_new_nearest_normals[dot_cur_n_with_prev_n < thres] = prev_nearest_base_normals[dot_cur_n_with_prev_n < thres] # new nearest base pts # # # # cur new nearest pts, normals # new_nearest_base_pts.append(cur_new_nearest_pts.unsqueeze(1)) new_nearest_base_normals.append(cur_new_nearest_normals.unsqueeze(1)) # # for i_f in range(1, nf): # cur_nearest_base_pts = nearest_base_pts[:, i_f] # bsz x nn_jts x 3 # # cur_nearest_base_normals = nearest_base_normals[:, i_f] # bsz x nn_jts x 3 # # prev_nearest_base_pts = nearest_base_pts[:, i_f - 1] # bsz x nn_jts x 3 # # prev_nearest_base_normals = nearest_base_normals[:, i_f - 1] # bsz x nn_jts x 3 # # # dot_cur_n_with_prev_n = torch.sum( # cur nearest base normals; nearest base normals # # cur_nearest_base_normals * prev_nearest_base_normals, dim=-1 # bsz x nn_jts for the nearest base normals and prev normals # # ) # cur_new_nearest_pts = cur_nearest_base_pts.clone() # bsz x nn_jts x 3 # for the base normals # # cur_new_nearest_normals = cur_nearest_base_normals.clone() # bsz x nn_jts x 3 # for base normals # # # if less than dot_cur_n_with_prev_n # then use # thres = -0.3 # cur_new_nearest_pts[dot_cur_n_with_prev_n < thres] = prev_nearest_base_pts[dot_cur_n_with_prev_n < thres] # the dot prev # cur_new_nearest_normals[dot_cur_n_with_prev_n < thres] = prev_nearest_base_normals[dot_cur_n_with_prev_n < thres] # # new nearest base pts # # # # # cur new nearest pts, normals # # new_nearest_base_pts.append(cur_new_nearest_pts.unsqueeze(1)) # new_nearest_base_normals.append(cur_new_nearest_normals.unsqueeze(1)) # # for i_f in range(1, nf): # cur_nearest_base_pts = nearest_base_pts[:, i_f] # bsz x nn_jts x 3 # # cur_nearest_base_normals = nearest_base_normals[:, i_f] # bsz x nn_jts x 3 # # prev_nearest_base_pts = nearest_base_pts[:, i_f - 1] # bsz x nn_jts x 3 # # prev_nearest_base_normals = nearest_base_normals[:, i_f - 1] # bsz x nn_jts x 3 # # # dot_cur_n_with_prev_n = torch.sum( # cur nearest base normals; nearest base normals # # cur_nearest_base_normals * prev_nearest_base_normals, dim=-1 # bsz x nn_jts for the nearest base normals and prev normals # # ) # cur_new_nearest_pts = cur_nearest_base_pts.clone() # bsz x nn_jts x 3 # for the base normals # # 
cur_new_nearest_normals = cur_nearest_base_normals.clone() # bsz x nn_jts x 3 # for base normals # # # if less than dot_cur_n_with_prev_n # then use # thres = -0.3 # cur_new_nearest_pts[dot_cur_n_with_prev_n < thres] = prev_nearest_base_pts[dot_cur_n_with_prev_n < thres] # the dot prev # cur_new_nearest_normals[dot_cur_n_with_prev_n < thres] = prev_nearest_base_normals[dot_cur_n_with_prev_n < thres] # # new nearest base pts # # # # # cur new nearest pts, normals # # new_nearest_base_pts.append(cur_new_nearest_pts.unsqueeze(1)) # new_nearest_base_normals.append(cur_new_nearest_normals.unsqueeze(1)) # new_nearest_base_pts = torch.cat(new_nearest_base_pts, dim=1) # bsz x nf x nn_jts x 3 --> for nearest base_pts and normals # new_nearest_base_normals = torch.cat(new_nearest_base_normals, dim=1) # bsz x nf x nn_jts x 3 --> for nearest base pts and normals # return new_nearest_base_pts, new_nearest_base_normals def pure_ccd(self, hand_rot, hand_theta, hand_beta, hand_transl, input_data, base_pts_exp, ): bsz, nf = hand_transl.size(0), hand_transl.size(1) hand_transl.requires_grad_() hand_rot.requires_grad_() hand_theta.requires_grad_() # if nn_tot_iters == 0: # hand verts; penetrated # how to project the hand verts # hand_verts, hand_joints = self.mano_layer(torch.cat([hand_rot.view(bsz * nf, -1), hand_theta.view(bsz * nf, -1)], dim=-1), hand_beta.unsqueeze(1).repeat(1, nf, 1).view(-1, 10), hand_transl.view(bsz * nf, -1)) hand_verts = hand_verts.view(bsz, nf, 778, 3).contiguous() * 0.001 i_bsz = 0 # if wccd: ccd_hand_verts, sel_faces_values, canon_verts1, canon_sel_faces_values = common_utils.collision_loss_sim_sequence_ours_ccd_rigid(hand_verts[i_bsz], input_data['obj_verts'][i_bsz], input_data['obj_faces'][0][0], input_data['obj_faces'][0][0], base_pts_exp[i_bsz], input_data['obj_rot'][i_bsz], input_data['obj_transl'][i_bsz], use_delta=False, sel_faces_values=self.sel_faces_values, canon_verts1=self.canon_verts1, canon_sel_faces_values=self.canon_sel_faces_values) self.sel_faces_values = sel_faces_values self.canon_verts1 = canon_verts1 self.canon_sel_faces_values = canon_sel_faces_values ccd_hand_verts = ccd_hand_verts.detach() # input_data['obj_rot']: bsz x nf x 3 x 3; # input_data['obj_transl']: bsz x nf x 3; # 1 x nf x nn_verts x 3 # print(f"ccd_hand_verts: {ccd_hand_verts.size()},", "obj_rot:", input_data['obj_rot'].size(), "obj_transl", input_data['obj_transl'].size()) ccd_hand_verts = torch.matmul(ccd_hand_verts.unsqueeze(0), input_data['obj_rot']) + input_data['obj_transl'].unsqueeze(-2) sv_ccd_hand_vert_dict = { 'ccd_hand_verts': ccd_hand_verts.detach().cpu().numpy(), } ccd_hand_verts_sv_fn = "ccd_hand_verts.npy" np.save(ccd_hand_verts_sv_fn, sv_ccd_hand_vert_dict) learning_rate = 0.1 opt = optim.Adam([hand_transl, hand_rot, hand_theta], lr=learning_rate) nn_tot_iters = 1000 # 100 # nn_tot_iters = 20 # 100 for i_iter in range(nn_tot_iters): # hand verts; penetrated # how to project the hand verts # hand_verts, hand_joints = self.mano_layer(torch.cat([hand_rot.view(bsz * nf, -1), hand_theta.view(bsz * nf, -1)], dim=-1), hand_beta.unsqueeze(1).repeat(1, nf, 1).view(-1, 10), hand_transl.view(bsz * nf, -1)) hand_verts = hand_verts.view(bsz, nf, 778, 3).contiguous() * 0.001 diff_hand_verts = torch.sum( (hand_verts - ccd_hand_verts) ** 2, dim=-1 ).mean() print(f"i_iter: {i_iter}, diff_hand_verts: {diff_hand_verts.item()}") opt.zero_grad() diff_hand_verts.backward() opt.step() # hand verts; base pts exp # print(f"ccd hand verts saved to {ccd_hand_verts_sv_fn}") def 
phy_project_pred_verts(self, pred_params, hand_beta, base_pts, base_normals, input_data, nn_iters=10, wccd=False): # pred_param: bsz x nf x nn_params; # hand_beta: bsz x nn_beta_dim; ## # input_data std_transl = input_data['std_transl'] avg_transl = input_data['avg_transl'] std_rot = input_data['std_rot'] avg_rot = input_data['avg_rot'] std_theta = input_data['std_theta'] avg_theta = input_data['avg_theta'] with th.enable_grad(): bsz = pred_params.size(0) nf = pred_params.size(1) # if not self.args.use_arti_obj: # base_pts_exp = base_pts.unsqueeze(1).repeat(1, nf, 1, 1).contiguous() # base_normals_exp = base_normals.unsqueeze(1).repeat(1, nf, 1, 1).contiguous() # else: base_pts_exp = base_pts.clone() base_normals_exp = base_normals.clone() # bz x nf x nn_params_dim; hand_transl, hand_rot, hand_theta = pred_params[..., :3].clone(), pred_params[..., 3: 6].clone(), pred_params[..., 6:].clone() # params hand_transl = torch.zeros_like(pred_params[..., :3]) hand_rot = torch.zeros_like(pred_params[..., 3: 6]) hand_theta = torch.zeros_like(pred_params[..., 6:]) hand_transl[:].data = pred_params[..., :3].data hand_rot[:].data = pred_params[..., 3: 6].data hand_theta[:].data = pred_params[..., 6:].data # base_pts_trans, base_normals_trans # # avg_transl, std_transl; avg_rot, std_rot ; avg_theta, std_theta # hand_transl = hand_transl * std_transl + avg_transl hand_rot = hand_rot * std_rot + avg_rot hand_theta = hand_theta * std_theta + avg_theta hand_transl.requires_grad_() hand_rot.requires_grad_() hand_theta.requires_grad_() # hand_beta.requires_grad_() learning_rate = 0.001 # penetration guidance; # learning_rate = 0.01 opt = optim.Adam([hand_transl, hand_rot, hand_theta], lr=learning_rate) nn_tot_iters = nn_iters # 100 # nn_tot_iters = 20 # 100 for i_iter in range(nn_tot_iters): # hand verts; penetrated # how to project the hand verts # hand_verts, hand_joints = self.mano_layer(torch.cat([hand_rot.view(bsz * nf, -1), hand_theta.view(bsz * nf, -1)], dim=-1), hand_beta.unsqueeze(1).repeat(1, nf, 1).view(-1, 10), hand_transl.view(bsz * nf, -1)) hand_verts = hand_verts.view(bsz, nf, 778, 3).contiguous() * 0.001 # hand verts; base pts exp # nearest_hand_verts_to_base_pts = torch.sum( (hand_verts.unsqueeze(-2) - base_pts_exp.unsqueeze(2)) ** 2, dim=-1 ) # hand verts and base pts ... 
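                # NOTE: the query above (and the torch.min that follows) is a
                # batched nearest-neighbour lookup: squared distances between
                # every hand vertex and every base point, then a min over the
                # base-point axis. Commented sketch with hypothetical shapes:
                #   verts = torch.randn(1, 4, 778, 3)   # bsz x nf x n_verts x 3
                #   base = torch.randn(1, 4, 700, 3)    # bsz x nf x n_base x 3
                #   d2 = torch.sum((verts.unsqueeze(-2) - base.unsqueeze(2)) ** 2, dim=-1)  # bsz x nf x n_verts x n_base
                #   nearest_d2, idx = torch.min(d2, dim=-1)                                 # bsz x nf x n_verts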
nearest_dist, nearest_base_pts_idxes = torch.min(nearest_hand_verts_to_base_pts, dim=-1) # bsz x nf x nn_jts nearest_dist = torch.sqrt(nearest_dist) nearest_base_pts = model_util.batched_index_select_ours(base_pts_exp, nearest_base_pts_idxes, dim=2) # bsz x nf x nn_jts x 3 nearest_base_normals = model_util.batched_index_select_ours(base_normals_exp, nearest_base_pts_idxes, dim=2) # bsz x nf x nn_jts x 3 # # bsz x nf x nn_jts x 3 for the base normals # ori_jts_to_base_pts = hand_verts - nearest_base_pts ori_dot_rel_with_normals = torch.sum( ori_jts_to_base_pts * nearest_base_normals, dim=-1 ) ##### ===== proj_loss v1 ===== ##### # ### from the middle key frame ### # # select_nearest_base_pts_via_normals_fr_mid(self, nearest_base_pts, nearest_base_normals): # nearest_base_pts, nearest_base_normals = self.select_nearest_base_pts_via_normals_fr_mid(nearest_base_pts, nearest_base_normals) # # nearest_base_pts and normals # # ### from the st frame ### # # nearest_base_pts, nearest_base_normals = self.select_nearest_base_pts_via_normals(nearest_base_pts, nearest_base_normals) # jts_to_base_pts = hand_verts - nearest_base_pts # from base pts to pred joints # # dot_rel_with_normals = torch.sum( # jts_to_base_pts * nearest_base_normals, dim=-1 # bsz x nf x nn_jts --> joints inside of the object # # ) # # dot_rel_with_normals = torch.sum( # # jts_to_base_pts * nearest_base_normals, dim=-1 # bsz x nf x nn_jts --> joints inside of the object # # # ) # # proj_loss f # proj_loss = torch.mean( # mean of loss # # (jts_to_base_pts ** 2)[dot_rel_with_normals < 0.] # ) ##### ===== proj_loss v1 ===== ##### tot_proj_loss = 0. for i_bsz in range(hand_verts.size(0)): print(hand_verts.size(), len(input_data['obj_verts']), ) proj_loss, sel_faces_values = common_utils.collision_loss_sim_sequence_ours(hand_verts[i_bsz], input_data['obj_verts'][i_bsz], input_data['obj_faces'][0][0], input_data['obj_faces'][0][0], base_pts_exp[i_bsz], use_delta=False, sel_faces_values=self.sel_faces_values) self.sel_faces_values = sel_faces_values tot_proj_loss += proj_loss proj_loss = tot_proj_loss / float(hand_verts.size(0)) # proj_loss = torch.mean( # mean of loss # # (jts_to_base_pts ** 2)[ori_dot_rel_with_normals < 0.] # ) print(f"proj_loss: {proj_loss.item()}") # jts_proj_dir = torch.zeros_like(nearest_base_pts) # bsz x nf x nn_jts x 3 # # jts_proj_dir[dot_rel_with_normals < 0.] = jts_to_base_pts[dot_rel_with_normals < 0.] 
# bsz x nf x nn_jts x 3 # opt.zero_grad() proj_loss.backward() opt.step() if nn_tot_iters == 0: # hand verts; penetrated # how to project the hand verts # hand_verts, hand_joints = self.mano_layer(torch.cat([hand_rot.view(bsz * nf, -1), hand_theta.view(bsz * nf, -1)], dim=-1), hand_beta.unsqueeze(1).repeat(1, nf, 1).view(-1, 10), hand_transl.view(bsz * nf, -1)) hand_verts = hand_verts.view(bsz, nf, 778, 3).contiguous() * 0.001 i_bsz = 0 if wccd: ccd_hand_verts, sel_faces_values, canon_verts1, canon_sel_faces_values = common_utils.collision_loss_sim_sequence_ours_ccd_rigid(hand_verts[i_bsz], input_data['obj_verts'][i_bsz], input_data['obj_faces'][0][0], input_data['obj_faces'][0][0], base_pts_exp[i_bsz], input_data['obj_rot'][i_bsz], input_data['obj_transl'][i_bsz], use_delta=False, sel_faces_values=self.sel_faces_values, canon_verts1=self.canon_verts1, canon_sel_faces_values=self.canon_sel_faces_values) self.sel_faces_values = sel_faces_values self.canon_verts1 = canon_verts1 self.canon_sel_faces_values = canon_sel_faces_values ccd_hand_verts = ccd_hand_verts.detach() # input_data['obj_rot']: bsz x nf x 3 x 3; # input_data['obj_transl']: bsz x nf x 3; # 1 x nf x nn_verts x 3 # print(f"ccd_hand_verts: {ccd_hand_verts.size()},", "obj_rot:", input_data['obj_rot'].size(), "obj_transl", input_data['obj_transl'].size()) ccd_hand_verts = torch.matmul(ccd_hand_verts.unsqueeze(0), input_data['obj_rot']) + input_data['obj_transl'].unsqueeze(-2) sv_ccd_hand_vert_dict = { 'ccd_hand_verts': ccd_hand_verts.detach().cpu().numpy(), } ccd_hand_verts_sv_fn = "ccd_hand_verts.npy" np.save(ccd_hand_verts_sv_fn, sv_ccd_hand_vert_dict) learning_rate = 0.1 opt = optim.Adam([hand_transl, hand_rot, hand_theta], lr=learning_rate) nn_tot_iters = 1000 # 100 # nn_tot_iters = 20 # 100 for i_iter in range(nn_tot_iters): # hand verts; penetrated # how to project the hand verts # hand_verts, hand_joints = self.mano_layer(torch.cat([hand_rot.view(bsz * nf, -1), hand_theta.view(bsz * nf, -1)], dim=-1), hand_beta.unsqueeze(1).repeat(1, nf, 1).view(-1, 10), hand_transl.view(bsz * nf, -1)) hand_verts = hand_verts.view(bsz, nf, 778, 3).contiguous() * 0.001 diff_hand_verts = torch.sum( (hand_verts - ccd_hand_verts) ** 2, dim=-1 ).mean() print(f"i_iter: {i_iter}, diff_hand_verts: {diff_hand_verts.item()}") opt.zero_grad() diff_hand_verts.backward() opt.step() # hand verts; base pts exp # print(f"ccd hand verts saved to {ccd_hand_verts_sv_fn}") hand_transl = (hand_transl - avg_transl) / std_transl hand_rot = (hand_rot - avg_rot) / std_rot hand_theta = (hand_theta - avg_theta) / std_theta optimized_pred_params = torch.cat( # bsz x nf x (3 + 3 + 24) [hand_transl.detach(), hand_rot.detach(), hand_theta.detach()], dim=-1 ) return optimized_pred_params # # the full physical world here? ## def p_mean_variance_cond( ## get mean data ## self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None ): """ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of the initial x, x_0. :param model: the model, which takes a signal and a batch of timesteps as input. :param x: the [N x C x ...] tensor at time t. :param t: a 1-D Tensor of timesteps. :param clip_denoised: if True, clip the denoised signal into [-1, 1]. # # clip the denoised signal ## :param denoised_fn: if not None, a function which applies to the x_start prediction before it is used to sample. Applies before clip_denoised. # denoised fn :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. 
This can be used for conditioning. :return: a dict with the following keys: - 'mean': the model mean output. - 'variance': the model variance output. - 'log_variance': the log of 'variance'. - 'pred_xstart': the prediction for x_0. """ # p_mean_varaince # if model_kwargs is None: model_kwargs = {} B = x['base_pts'].shape[0] assert t.shape == (B,) # print(f"t_shape: {t.shape}", "base_pts", x['base_pts'].size()) input_data = x out_dict = model(input_data, self._scale_timesteps(t).clone()) rt_dict = {} real_basejtsrel_seq_rt_dict = {} basejtsrel_seq_rt_dict = {} realbasejtsrel_to_joints_rt_dict = {} model_variance, model_log_variance = self.posterior_variance, self.posterior_log_variance_clipped if self.diff_realbasejtsrel and self.diff_basejtsrel: # diff_realbasejtsrel, real_basejtsrel_seq_rt_dict pert_rel_base_pts_outputs = x['pert_joints_offset_sequence'] basejtsrel_output = out_dict['joints_offset_output'] score_jts = basejtsrel_output # print(f"basejtsrel_output: {basejtsrel_output.size()}") # if self.args.use_var_sched: # bsz = basejtsrel_output.size(0) # t_item = t[0].item() # alpha = self.var_sched.alphas[t_item] # alpha_bar = self.var_sched.alpha_bars[t_item] # sigma = self.var_sched.get_sigmas(t_item, 0.) # sigma # # c0 = 1.0 / torch.sqrt(alpha) # c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar) # beta = self.var_sched.betas[[t[0].item()] * bsz] # z = torch.randn_like(basejtsrel_output) if t_item > 0 else torch.zeros_like(basejtsrel_output) # basejtsrel_output = c0 * (pert_rel_base_pts_outputs - c1 * basejtsrel_output) + sigma * z # theta # else: # basejtsrel_output = self._predict_xstart_from_eps(pert_rel_base_pts_outputs, t=t, eps=basejtsrel_output) real_dec_basejtsrel = out_dict['real_dec_basejtsrel'] # bsz x nf x nnj x nnb x 3 pert_rel_base_pts_to_rhand_joints = x['pert_rel_base_pts_to_rhand_joints'] # combine those two things # if self.args.pred_diff_noise and not self.args.train_enc: if self.args.add_noise_onjts: # add noise onjts # if self.args.real_basejtsrel_norm_stra == "std": ### std denormed_rel_base_pts_to_rhand_joints = pert_rel_base_pts_to_rhand_joints * x['std_rel_base_pts_to_rhand_joints'] + x['avg_rel_base_pts_to_rhand_joints'] else: denormed_rel_base_pts_to_rhand_joints = pert_rel_base_pts_to_rhand_joints # jts_fr_basepts = denormed_rel_base_pts_to_rhand_joints + x['normed_base_pts'].unsqueeze(1).unsqueeze(1) # jts_fr_basepts = jts_fr_basepts.mean(dim=-2) jts_fr_basepts = pert_rel_base_pts_outputs # pert # score_jts_fr_basepts = real_dec_basejtsrel.mean(dim=-2) # bsz x seq x nnjts x 3 score_jts_fr_basepts = real_dec_basejtsrel[..., self.args.sel_basepts_idx, :] t_item = t[0].item() alpha_bar = self.var_sched.alpha_bars[t_item] # combined_socre = score_jts - torch.sqrt(1. - alpha_bar) * score_jts_fr_basepts # combined_socre = score_jts # - torch.sqrt(1. - alpha_bar) * socre_jts_fr_basepts # combined_socre = score_jts_fr_basepts if self.args.only_cmb_finger: combined_socre = score_jts # combined_socre[..., -5:, :] = combined_socre[..., -5:, :] - (torch.sqrt(1. 
- alpha_bar) * score_jts_fr_basepts)[..., -5:, :] # combined_socre[..., -5:, :] = combined_socre[..., -5:, :] * 0.7 + score_jts_fr_basepts[..., -5:, :] * 0.3 # combined_socre = combined_socre * 0.2 + score_jts_fr_basepts * 0.8 combined_socre = combined_socre * 0.1 + score_jts_fr_basepts * 0.9 # combined_socre = combined_socre * 0.05 + score_jts_fr_basepts * 0.95 # combined_socre = combined_socre * 0.5 + score_jts_fr_basepts * 0.5 # combined_socre = combined_socre # combined_socre = score_jts_fr_basepts else: combined_socre = score_jts - torch.sqrt(1. - alpha_bar) * score_jts_fr_basepts combined_socre = score_jts # not cmb finger # # combined_socre = score_jts_fr_basepts # print(f"real_dec_basejtsrel: {real_dec_basejtsrel.size()}, score_jts_fr_basepts: {score_jts_fr_basepts.size()}, combined_socre: {combined_socre.size()}, score_jts: {score_jts.size()}") if self.args.use_var_sched: bsz = real_dec_basejtsrel.size(0) t_item = t[0].item() alpha = self.var_sched.alphas[t_item] alpha_bar = self.var_sched.alpha_bars[t_item] sigma = self.var_sched.get_sigmas(t_item, 0.) # sigma # # sigma = sigma / 2. # use_var_sched -> # c0 = 1.0 / torch.sqrt(alpha) c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar) beta = self.var_sched.betas[[t[0].item()] * bsz] z = torch.randn_like(jts_fr_basepts) if t_item > 0 else torch.zeros_like(jts_fr_basepts) dec_jts_fr_basepts = c0 * (jts_fr_basepts - c1 * combined_socre) + sigma * z # theta else: # x_{t-1} dec_jts_fr_basepts = self._predict_xstart_from_eps(jts_fr_basepts, t=t, eps=combined_socre) if not self.args.use_arti_obj: real_dec_basejtsrel = dec_jts_fr_basepts.unsqueeze(-2) - x['normed_base_pts'].unsqueeze(1).unsqueeze(1) else: real_dec_basejtsrel = dec_jts_fr_basepts.unsqueeze(-2) - x['normed_base_pts'].unsqueeze(2) # print(f"dec_jts_fr_basepts: {dec_jts_fr_basepts.size()}, normed_base_pts: ", x['normed_base_pts'].size(), "real_dec_basejtsrel:", real_dec_basejtsrel.size()) if self.args.real_basejtsrel_norm_stra == "std": real_dec_basejtsrel = (real_dec_basejtsrel - x['avg_rel_base_pts_to_rhand_joints']) / x['std_rel_base_pts_to_rhand_joints'] # print(f"real_dec_basejtsrel: {real_dec_basejtsrel.size()}, denormed_rel_base_pts_to_rhand_joints: {denormed_rel_base_pts_to_rhand_joints.size()}, jts_fr_basepts: {jts_fr_basepts.size()}") elif self.args.add_noise_onjts_single: # add noise on single joint if not self.args.use_arti_obj: jts_fr_basepts = pert_rel_base_pts_to_rhand_joints + x['normed_base_pts'].unsqueeze(1).unsqueeze(1) else: jts_fr_basepts = pert_rel_base_pts_to_rhand_joints + x['normed_base_pts'].unsqueeze(2) jts_fr_basepts = jts_fr_basepts.mean(dim=-2) dec_jts_fr_basepts = real_dec_basejtsrel.mean(dim=-2) # + x['normed_base_pts'].unsqueeze(1).unsqueeze(1) # rel dec basejtsrel # # dec_jts_fr_basepts = real_dec_basejtsrel[..., 0, :] t_item = t[0].item() alpha_bar = self.var_sched.alpha_bars[t_item] socre_jts_fr_basepts = dec_jts_fr_basepts if self.args.only_cmb_finger: combined_socre = score_jts # strategy 1 --> conditioning # # combined_socre[..., -5:, :] = combined_socre[..., -5:, :] - (torch.sqrt(1. - alpha_bar) * socre_jts_fr_basepts)[..., -5:, :] # strategy 2 --> linear interpolation # # combined_socre[..., -5:, :] = combined_socre[..., -5:, :] * 0.7 + socre_jts_fr_basepts[..., -5:, :] * 0.3 combined_socre[..., -5:, :] = combined_socre[..., -5:, :] * 0.5 + socre_jts_fr_basepts[..., -5:, :] * 0.5 else: combined_socre = score_jts - torch.sqrt(1. - alpha_bar) * socre_jts_fr_basepts combined_socre = score_jts # - torch.sqrt(1. 
- alpha_bar) * socre_jts_fr_basepts # combined_socre = socre_jts_fr_basepts if self.args.use_var_sched: bsz = dec_jts_fr_basepts.size(0) alpha = self.var_sched.alphas[t_item] alpha_bar = self.var_sched.alpha_bars[t_item] sigma = self.var_sched.get_sigmas(t_item, 0.) # sigma # c0 = 1.0 / torch.sqrt(alpha) c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar) beta = self.var_sched.betas[[t[0].item()] * bsz] z = torch.randn_like(combined_socre) if t_item > 0 else torch.zeros_like(jts_fr_basepts) dec_jts_fr_basepts = c0 * (pert_rel_base_pts_outputs - c1 * combined_socre) + sigma * z # theta # bsz x seq x nnjts x nnbase x 3 # else: dec_jts_fr_basepts = self._predict_xstart_from_eps(pert_rel_base_pts_outputs, t=t, eps=combined_socre) if not self.args.use_arti_obj: real_dec_basejtsrel = dec_jts_fr_basepts.unsqueeze(-2) - x['normed_base_pts'].unsqueeze(1).unsqueeze(1) else: real_dec_basejtsrel = dec_jts_fr_basepts.unsqueeze(-2) - x['normed_base_pts'].unsqueeze(2) else: # raise ValueError(f"Add noise directly --- not implemented yet") # # input_data # self.args.real_basejtsrel_norm_stra == "std": # avg_rel_base_pts_to_rhand_joints, std_rel_base_pts_to_rhand_joints if self.args.real_basejtsrel_norm_stra == "std": ### std denormed_rel_base_pts_to_rhand_joints = pert_rel_base_pts_to_rhand_joints * x['std_rel_base_pts_to_rhand_joints'] + x['avg_rel_base_pts_to_rhand_joints'] else: denormed_rel_base_pts_to_rhand_joints = pert_rel_base_pts_to_rhand_joints if not self.args.use_arti_obj: jts_fr_basepts = denormed_rel_base_pts_to_rhand_joints + x['normed_base_pts'].unsqueeze(1).unsqueeze(1) else: jts_fr_basepts = denormed_rel_base_pts_to_rhand_joints + x['normed_base_pts'].unsqueeze(2) jts_fr_basepts = jts_fr_basepts.mean(dim=-2) score_jts_fr_basepts = real_dec_basejtsrel.mean(dim=-2) # bsz x seq x nnjts x 3 t_item = t[0].item() alpha_bar = self.var_sched.alpha_bars[t_item] combined_socre = score_jts - torch.sqrt(1. - alpha_bar) * socre_jts_fr_basepts combined_socre = score_jts # - torch.sqrt(1. - alpha_bar) * socre_jts_fr_basepts combined_socre = score_jts_fr_basepts if self.args.use_var_sched: bsz = real_dec_basejtsrel.size(0) t_item = t[0].item() alpha = self.var_sched.alphas[t_item] alpha_bar = self.var_sched.alpha_bars[t_item] sigma = self.var_sched.get_sigmas(t_item, 0.) # sigma # # sigma = sigma / 2. 
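            # NOTE (illustrative sketch, not executed): the c0/c1 update below
            # is the standard DDPM ancestral step written in eps-form,
            #   x_{t-1} = 1/sqrt(alpha_t) * (x_t - (1 - alpha_t)/sqrt(1 - alpha_bar_t) * eps_theta) + sigma_t * z,
            # with z ~ N(0, I) and the noise suppressed at t == 0. Assuming
            # scalar schedule tensors alpha, alpha_bar, sigma (names hypothetical):
            # def ddpm_ancestral_step(x_t, eps_theta, alpha, alpha_bar, sigma, t_item):
            #     c0 = 1.0 / torch.sqrt(alpha)
            #     c1 = (1.0 - alpha) / torch.sqrt(1.0 - alpha_bar)
            #     z = torch.randn_like(x_t) if t_item > 0 else torch.zeros_like(x_t)
            #     return c0 * (x_t - c1 * eps_theta) + sigma * z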
c0 = 1.0 / torch.sqrt(alpha) c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar) beta = self.var_sched.betas[[t[0].item()] * bsz] z = torch.randn_like(jts_fr_basepts) if t_item > 0 else torch.zeros_like(jts_fr_basepts) dec_jts_fr_basepts = c0 * (jts_fr_basepts - c1 * combined_socre) + sigma * z # theta else: # x_{t-1} dec_jts_fr_basepts = self._predict_xstart_from_eps(jts_fr_basepts, t=t, eps=combined_socre) if not self.args.use_arti_obj: real_dec_basejtsrel = dec_jts_fr_basepts.unsqueeze(-2) - x['normed_base_pts'].unsqueeze(1).unsqueeze(1) else: real_dec_basejtsrel = dec_jts_fr_basepts.unsqueeze(-2) - x['normed_base_pts'].unsqueeze(2) if self.args.real_basejtsrel_norm_stra == "std": real_dec_basejtsrel = (real_dec_basejtsrel - x['avg_rel_base_pts_to_rhand_joints']) / x['std_rel_base_pts_to_rhand_joints'] # diff_realbasejtsrel real_basejtsrel_seq_rt_dict = { "real_dec_basejtsrel": real_dec_basejtsrel, } basejtsrel_seq_rt_dict = { "basejtsrel_output": dec_jts_fr_basepts } if self.args.train_enc: # pert_rel_base_pts_to_rhand_joints raise ValueError(f"Trian enc --- Not implemented yet") pert_obj_base_pts_feats = x['pert_obj_base_pts_feats'] dec_obj_base_pts_feats = model.model.denoising_realbasejtsrel_objbasefeats(pert_obj_base_pts_feats, t) dec_obj_base_pts_feats = dec_obj_base_pts_feats.permute(1, 0, 2) pert_obj_base_pts_feats = pert_obj_base_pts_feats.permute(1, 0, 2) if self.args.pred_diff_noise: bsz = dec_obj_base_pts_feats.size(0) t_item = t[0].item() alpha = self.var_sched.alphas[t_item] alpha_bar = self.var_sched.alpha_bars[t_item] sigma = self.var_sched.get_sigmas(t_item, 0.) # sigma # c0 = 1.0 / torch.sqrt(alpha) c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar) beta = self.var_sched.betas[[t[0].item()] * bsz] z = torch.randn_like(dec_obj_base_pts_feats) if t_item > 0 else torch.zeros_like(dec_obj_base_pts_feats) dec_obj_base_pts_feats = c0 * (pert_obj_base_pts_feats - c1 * dec_obj_base_pts_feats) + sigma * z # theta dec_obj_base_pts_feats = dec_obj_base_pts_feats.permute(1, 0, 2) real_dec_basejtsrel = model.model.decode_realbasejtsrel_from_objbasefeats(dec_obj_base_pts_feats, x) real_basejtsrel_seq_rt_dict.update( { 'real_dec_basejtsrel': real_dec_basejtsrel, 'dec_obj_base_pts_feats': dec_obj_base_pts_feats, } ) # else: # real_basejtsrel_seq_rt_dict = {} # basejtsrel_seq_rt_dict = {} if self.diff_basejtsrel and self.args.diff_realbasejtsrel_to_joints: pert_rel_base_pts_outputs = x['pert_joints_offset_sequence'] basejtsrel_output = out_dict['joints_offset_output'] score_jts = basejtsrel_output pert_joints_offset_output = x['pert_joints_offset_sequence'] # rel base pts outputs # dec_joints_offset_output = out_dict['joints_offset_output_from_rel'] # joints offset output # score_jts_fr_rel = dec_joints_offset_output # # pert joints offset sequence # # if self.args.pred_diff_noise: ## eps -> estimated-noises # predict noise? # # # if self.args.use_var_sched: bsz = dec_joints_offset_output.size(0) ## batch size fo the batch, pert-joints, predicted-joints ## t_item = t[0].item() alpha_bar = self.var_sched.alpha_bars[t_item] combined_socre = score_jts - torch.sqrt(1. - alpha_bar) * score_jts_fr_rel # combined_socre = score_jts # - torch.sqrt(1. - alpha_bar) * socre_jts_fr_basepts # combined_socre = score_jts_fr_rel alpha = self.var_sched.alphas[t_item] alpha_bar = self.var_sched.alpha_bars[t_item] sigma = self.var_sched.get_sigmas(t_item, 0.) # sigma # # sigma # c0 = 1.0 / torch.sqrt(alpha) ### c0? c1? 
--> c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar) beta = self.var_sched.betas[[t[0].item()] * bsz] z = torch.randn_like(dec_joints_offset_output) if t_item > 0 else torch.zeros_like(dec_joints_offset_output) # dec_joints dec_joints_offset_output = c0 * (pert_joints_offset_output - c1 * score_jts) + sigma * z # theta ### realjtsrel_to_joints and joints only ## realbasejtsrel_to_joints_rt_dict = { 'dec_joints_offset_output': dec_joints_offset_output, } basejtsrel_seq_rt_dict = { "basejtsrel_output": dec_joints_offset_output } # else: # realbasejtsrel_to_joints_rt_dict = {} # basejtsrel_seq_rt_dict = {} # rt_dict.update(jts_seq_rt_dict) rt_dict.update(basejtsrel_seq_rt_dict) # rt_dict.update(basejtse_seq_rt_dict) rt_dict.update(real_basejtsrel_seq_rt_dict) rt_dict.update(realbasejtsrel_to_joints_rt_dict) return rt_dict def p_mean_variance( ## get mean data ## self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None ): """ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of the initial x, x_0. :param model: the model, which takes a signal and a batch of timesteps as input. :param x: the [N x C x ...] tensor at time t. :param t: a 1-D Tensor of timesteps. :param clip_denoised: if True, clip the denoised signal into [-1, 1]. # # clip the denoised signal ## :param denoised_fn: if not None, a function which applies to the x_start prediction before it is used to sample. Applies before clip_denoised. # denoised fn :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :return: a dict with the following keys: - 'mean': the model mean output. - 'variance': the model variance output. - 'log_variance': the log of 'variance'. - 'pred_xstart': the prediction for x_0. 
""" if model_kwargs is None: model_kwargs = {} ## === version 1 -> predict x_start at the rel domain === ## ### == x -> formulated as model_inputs == ### ### == TODO: process x_start from the predicted value (relative positions, signed distances) with base_pts and baes_normals == ### B = x['base_pts'].shape[0] assert t.shape == (B,) input_data = x ## dec_out and out ## ## output dict ## out_dict = model(input_data, self._scale_timesteps(t).clone()) rt_dict = {} # # }[self.model_var_type] # ### === model variance and log_variance === ### ## posterior_log_variance_clipped, posterior_variance ## model_variance, model_log_variance = self.posterior_variance, self.posterior_log_variance_clipped # pmean variance if self.diff_jts: # x_t ## joints seq latents ## pert_joints_seq_latents = x['joints_seq_latents'] # x_t pred_clean_joint_seq_latents = out_dict["joints_seq_latents"] ## if self.args.pred_diff_noise: ## eps -> estimated noises ## t > for added joints latents ## pred_clean_joint_seq_latents = self._predict_xstart_from_eps(pert_joints_seq_latents.permute(1, 0, 2), t=t, eps=pred_clean_joint_seq_latents.permute(1, 0, 2)).permute(1, 0, 2) # seq x bs x d # # minn_pert_joints_seq_latents, _ = torch.min(pred_clean_joint_seq_latents.view(pred_clean_joint_seq_latents.size(0) * pred_clean_joint_seq_latents.size(1), -1).contiguous(), dim=0) # maxx_pert_joints_seq_latents, _ = torch.max(pred_clean_joint_seq_latents.view(pred_clean_joint_seq_latents.size(0) * pred_clean_joint_seq_latents.size(1), -1).contiguous(), dim=0) # print(f"pred minn_pert_joints_latents: {minn_pert_joints_seq_latents[:10]}, pred maxx_pert_joints_seq_latents: {maxx_pert_joints_seq_latents[:10]}") ## out_dict["joint_seq_output"] = model.model.dec_jts_only_fr_latents(pred_clean_joint_seq_latents)["joint_seq_output"] ## joints seq latents mean # # pred_clean_joint_seq_latents = pert_joints_seq_latents ## joints seq latents mean # joints_seq_latents_mean, _, _ = self.q_posterior_mean_variance( x_start=pred_clean_joint_seq_latents.permute(1, 0, 2), x_t=pert_joints_seq_latents.permute(1, 0, 2), t=t ) # joints_seq_latents_mean, basejtsrel_seq_latents_mean # joints_seq_latents_mean = joints_seq_latents_mean.permute(1, 0, 2) # joints_seq_latents_mean, basejtsrel_seq_latents_mean # joints_seq_latents_variance = _extract_into_tensor(model_variance, t, joints_seq_latents_mean.permute(1, 0, 2).shape).permute(1, 0, 2) joints_seq_latents_log_variance = _extract_into_tensor(model_log_variance, t, joints_seq_latents_mean.permute(1, 0, 2).shape).permute(1, 0, 2) # joint seq output # # joint seq output # joint_seq_output = out_dict["joint_seq_output"] jts_seq_rt_dict = { ### joints seq latents ### "joints_seq_latents_mean": joints_seq_latents_mean, "joints_seq_latents_variance": joints_seq_latents_variance, "joints_seq_latents_log_variance": joints_seq_latents_log_variance, ### decoded output values ### "joint_seq_output": joint_seq_output, } else: jts_seq_rt_dict = {} # /data1/sim/mdm/save/predoffset_stdscale_bsz_10_pred_diff_realbaesjtsrel_nonorm_std_for_norm_train_enc_with_diff_latents_prediffnoise_none_norm_rel_rel_to_jts_/model000007000.pt if self.args.diff_realbasejtsrel_to_joints: pert_joints_offset_output = x['pert_joints_offset_sequence'] # rel base pts outputs # dec_joints_offset_output = out_dict['joints_offset_output'] # joints offset output # if self.args.pred_diff_noise: ## eps -> estimated-noises # predict noise? 
# # if self.args.use_var_sched: bsz = dec_joints_offset_output.size(0) ## batch size fo the batch, pert-joints, predicted-joints ## t_item = t[0].item() alpha = self.var_sched.alphas[t_item] alpha_bar = self.var_sched.alpha_bars[t_item] sigma = self.var_sched.get_sigmas(t_item, 0.) # sigma # c0 = 1.0 / torch.sqrt(alpha) ### c0? c1? --> beta, z c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar) beta = self.var_sched.betas[[t[0].item()] * bsz] z = torch.randn_like(dec_joints_offset_output) if t_item > 0 else torch.zeros_like(dec_joints_offset_output) # dec_joints dec_joints_offset_output = c0 * (pert_joints_offset_output - c1 * dec_joints_offset_output) + sigma * z # theta ## use the predicted latents and pert_latents for the seq latents prediction ## dec_joints_offset_output_mean, _, _ = self.q_posterior_mean_variance( # q posterior x_start=dec_joints_offset_output, x_t=pert_joints_offset_output, t=t ) ## from model_variance to basejtsrel_seq_latents ### dec_joints_offset_output_variance = _extract_into_tensor(model_variance, t, dec_joints_offset_output.shape) dec_joints_offset_output_log_variance = _extract_into_tensor(model_log_variance, t, dec_joints_offset_output.shape) realbasejtsrel_to_joints_rt_dict = { 'dec_joints_offset_output': dec_joints_offset_output, } else: realbasejtsrel_to_joints_rt_dict = {} if self.diff_realbasejtsrel: # diff_realbasejtsrel, real_basejtsrel_seq_rt_dict real_dec_basejtsrel = out_dict['real_dec_basejtsrel'] # bsz x nf x nnj x nnb x 3 pert_rel_base_pts_to_rhand_joints = x['pert_rel_base_pts_to_rhand_joints'] if self.args.pred_diff_noise and not self.args.train_enc: if self.args.add_noise_onjts: if not self.args.use_arti_obj: jts_fr_basepts = pert_rel_base_pts_to_rhand_joints + x['normed_base_pts'].unsqueeze(1).unsqueeze(1) else: jts_fr_basepts = pert_rel_base_pts_to_rhand_joints + x['normed_base_pts'].unsqueeze(2) # jts_fr_basepts = pert_rel_base_pts_to_rhand_joints[:, :, :, 0:1, :] + x['normed_base_pts'].unsqueeze(1).unsqueeze(1) dec_jts_fr_basepts = real_dec_basejtsrel # + x['normed_base_pts'].unsqueeze(1).unsqueeze(1) # rel dec basejtsrel # ## use same noise for rep ## use noise for rep ## ### use_same_noise_for_rep --> use same noise for rep ### if self.args.use_same_noise_for_rep: # # convert them to the strategy of using single noise ## if self.args.sel_basepts_idx >= 0: # real dec base jts rel # dec_jts_fr_basepts = real_dec_basejtsrel[:, :, :, self.args.sel_basepts_idx: self.args.sel_basepts_idx + 1] else: dec_jts_fr_basepts = real_dec_basejtsrel.mean(dim=-2, keepdim=True) # dec noise # # [:, :, :, self.args.sel_basepts_idx: self.args.sel_basepts_idx + 1] # from noise and x_t to x_start; # a projection strategy for x_start; # to noise # and we want to adjust nosie # if self.args.phy_guided_sampling and t[0].item() < 1: # # phy_guided_sampling # # pred_dec_jts = self._predict_xstart_from_eps(jts_fr_basepts, t, dec_jts_fr_basepts) # bsz x nf x nn_jts x nn_base_pts x 3 # if self.args.sel_basepts_idx >= 0: # pred_dec_jts = pred_dec_jts[:, :, :, self.args.sel_basepts_idx] # bsz x nf x nn_jts x 3 # # else: # pred_dec_jts = pred_dec_jts.mean(dim=-2) # bsz x nf x nn_jts x 3 # # # phy_projct_pred_joints(self, pred_joints, base_pts, base_normals) # joints_proj_dir = self.phy_projct_pred_joints(pred_dec_jts, x['normed_base_pts'], x['base_normals']) # bsz x nf x nn_jts x 3 # x_start_projed = pred_dec_jts - joints_proj_dir # x_start_splat = x_start_projed.unsqueeze(-2) - x['normed_base_pts'].unsqueeze(1).unsqueeze(1) # bsz x nf x nn_jts x nn_base_pts x 3 ## # 
dec_jts_fr_basepts_projed = self._predict_eps_from_xstart(jts_fr_basepts, t, x_start_splat) # bsz x nf x nn_jts x nn_base_pts x 3 # # dec_ratio = 0.95 # dec_jts_fr_basepts = dec_jts_fr_basepts * dec_ratio + dec_jts_fr_basepts_projed * (1. - dec_ratio) if self.args.use_var_sched: bsz = dec_jts_fr_basepts.size(0) t_item = t[0].item() alpha = self.var_sched.alphas[t_item] alpha_bar = self.var_sched.alpha_bars[t_item] sigma = self.var_sched.get_sigmas(t_item, 0.) # sigma # c0 = 1.0 / torch.sqrt(alpha) c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar) beta = self.var_sched.betas[[t[0].item()] * bsz] z = torch.randn_like(dec_jts_fr_basepts) if t_item > 0 else torch.zeros_like(dec_jts_fr_basepts) dec_jts_fr_basepts = c0 * (jts_fr_basepts - c1 * dec_jts_fr_basepts) + sigma * z # theta # bsz x seq x nnjts x nnbase x 3 # # real_dec_basejtsrel = c0 * (pert_rel_base_pts_to_rhand_joints - c1 * real_dec_basejtsrel) + sigma * z # theta ## dec jts fr base pts ## # dec jts fr # dec_jts_fr_basepts = dec_jts_fr_basepts.mean(dim=-2).unsqueeze(-2).repeat(1, 1, 1, dec_jts_fr_basepts.size(-2), 1) # bsz x seq x nnjts x nnbase x 3 # else: # x_{t-1} dec_jts_fr_basepts = self._predict_xstart_from_eps(jts_fr_basepts, t=t, eps=dec_jts_fr_basepts) if self.args.phy_guided_sampling and t[0].item() < 0: # phy_guided_sampling # # pred_dec_jts = self._predict_xstart_from_eps(jts_fr_basepts, t, dec_jts_fr_basepts) # bsz x nf x nn_jts x nn_base_pts x 3 pred_dec_jts = dec_jts_fr_basepts.clone() if self.args.sel_basepts_idx >= 0: pred_dec_jts = pred_dec_jts[:, :, :, self.args.sel_basepts_idx] # bsz x nf x nn_jts x 3 # else: pred_dec_jts = pred_dec_jts.mean(dim=-2) # bsz x nf x nn_jts x 3 # # phy_projct_pred_joints(self, pred_joints, base_pts, base_normals) joints_proj_dir = self.phy_projct_pred_joints(pred_dec_jts, x['normed_base_pts'], x['base_normals']) # bsz x nf x nn_jts x 3 x_start_projed = pred_dec_jts - joints_proj_dir # x_start_splat = x_start_projed.unsqueeze(-2) - x['normed_base_pts'].unsqueeze(1).unsqueeze(1) # bsz x nf x nn_jts x nn_base_pts x 3 ## # dec_jts_fr_basepts_projed = self._predict_eps_from_xstart(jts_fr_basepts, t, x_start_splat) # bsz x nf x nn_jts x nn_base_pts x 3 # dec_ratio = 0. dec_jts_fr_basepts = dec_jts_fr_basepts * dec_ratio + x_start_projed.unsqueeze(-2) * (1. - dec_ratio) if not self.args.use_arti_obj: real_dec_basejtsrel = dec_jts_fr_basepts - x['normed_base_pts'].unsqueeze(1).unsqueeze(1) else: real_dec_basejtsrel = dec_jts_fr_basepts - x['normed_base_pts'].unsqueeze(2) elif self.args.add_noise_onjts_single: if not self.args.use_arti_obj: jts_fr_basepts = pert_rel_base_pts_to_rhand_joints + x['normed_base_pts'].unsqueeze(1).unsqueeze(1) else: jts_fr_basepts = pert_rel_base_pts_to_rhand_joints + x['normed_base_pts'].unsqueeze(2) jts_fr_basepts = jts_fr_basepts.mean(dim=-2) # dec_jts_fr_basepts = real_dec_basejtsrel.mean(dim=-2) # + x['normed_base_pts'].unsqueeze(1).unsqueeze(1) # rel dec basejtsrel # dec_jts_fr_basepts = real_dec_basejtsrel[..., 0, :] if self.args.use_var_sched: bsz = dec_jts_fr_basepts.size(0) t_item = t[0].item() alpha = self.var_sched.alphas[t_item] alpha_bar = self.var_sched.alpha_bars[t_item] sigma = self.var_sched.get_sigmas(t_item, 0.) 
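        # NOTE (illustrative sketch, not executed): the phy_guided_sampling
        # branch above blends the denoised joints with a physically projected
        # copy -- penetrating joints are pushed out of the object along their
        # projection direction, then linearly mixed with the raw prediction
        # (dec_ratio = 0. keeps only the projected joints). A hypothetical
        # free-function form of that blend:
        # def phy_guided_blend(pred_jts, base_pts, base_normals, proj_fn, dec_ratio=0.):
        #     proj_dir = proj_fn(pred_jts, base_pts, base_normals)  # e.g. self.phy_projct_pred_joints
        #     projected = pred_jts - proj_dir  # pushed out of the object
        #     return pred_jts * dec_ratio + projected * (1. - dec_ratio)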
# sigma # c0 = 1.0 / torch.sqrt(alpha) c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar) beta = self.var_sched.betas[[t[0].item()] * bsz] z = torch.randn_like(jts_fr_basepts) if t_item > 0 else torch.zeros_like(jts_fr_basepts) dec_jts_fr_basepts = c0 * (jts_fr_basepts - c1 * dec_jts_fr_basepts) + sigma * z # theta # bsz x seq x nnjts x nnbase x 3 # else: # x_{t-1} dec_jts_fr_basepts = self._predict_xstart_from_eps(jts_fr_basepts, t=t, eps=dec_jts_fr_basepts) # real_dec_basejtsrel = dec_jts_fr_basepts.unsqueeze(-2) - x['normed_base_pts'].unsqueeze(1).unsqueeze(1) if not self.args.use_arti_obj: real_dec_basejtsrel = dec_jts_fr_basepts.unsqueeze(-2) - x['normed_base_pts'].unsqueeze(1).unsqueeze(1) else: real_dec_basejtsrel = dec_jts_fr_basepts.unsqueeze(-2) - x['normed_base_pts'].unsqueeze(2) else: if self.args.use_var_sched: bsz = real_dec_basejtsrel.size(0) t_item = t[0].item() alpha = self.var_sched.alphas[t_item] alpha_bar = self.var_sched.alpha_bars[t_item] sigma = self.var_sched.get_sigmas(t_item, 0.) # sigma # # sigma = sigma / 2. c0 = 1.0 / torch.sqrt(alpha) c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar) beta = self.var_sched.betas[[t[0].item()] * bsz] z = torch.randn_like(pert_rel_base_pts_to_rhand_joints) if t_item > 0 else torch.zeros_like(pert_rel_base_pts_to_rhand_joints) # z = torch.zeros_like(pert_rel_base_pts_to_rhand_joints) real_dec_basejtsrel = c0 * (pert_rel_base_pts_to_rhand_joints - c1 * real_dec_basejtsrel) + sigma * z # theta # dec_jts_fr_basepts = real_dec_basejtsrel + x['normed_base_pts'].unsqueeze(1).unsqueeze(1) # get dec_jts fr basepts # dec_jts_fr_basepts = dec_jts_fr_basepts.mean(dim=-2).unsqueeze(-2).repeat(1, 1, 1, dec_jts_fr_basepts.size(-2), 1) # repeated basepts # real_dec_basejtsrel = dec_jts_fr_basepts - x['normed_base_pts'].unsqueeze(1).unsqueeze(1) # real dec # else: # x_{t-1} real_dec_basejtsrel = self._predict_xstart_from_eps(pert_rel_base_pts_to_rhand_joints, t=t, eps=real_dec_basejtsrel) ## use the predicted latents and pert_latents for the seq latents prediction ## real_dec_basejtsrel_mean, _, _ = self.q_posterior_mean_variance( # q posterior x_start=real_dec_basejtsrel, x_t=pert_rel_base_pts_to_rhand_joints, t=t ) ## from model_variance to basejtsrel_seq_latents ### real_dec_basejtsrel_variance = _extract_into_tensor(model_variance, t, real_dec_basejtsrel.shape) real_dec_basejtsrel_log_variance = _extract_into_tensor(model_log_variance, t, real_dec_basejtsrel.shape) # diff_realbasejtsrel real_basejtsrel_seq_rt_dict = { "real_dec_basejtsrel": real_dec_basejtsrel, "real_dec_basejtsrel_mean": real_dec_basejtsrel_mean, "real_dec_basejtsrel_variance": real_dec_basejtsrel_variance, "real_dec_basejtsrel_log_variance": real_dec_basejtsrel_log_variance, } if self.args.train_enc: # pert_rel_base_pts_to_rhand_joints pert_obj_base_pts_feats = x['pert_obj_base_pts_feats'] dec_obj_base_pts_feats = model.model.denoising_realbasejtsrel_objbasefeats(pert_obj_base_pts_feats, t) dec_obj_base_pts_feats = dec_obj_base_pts_feats.permute(1, 0, 2) pert_obj_base_pts_feats = pert_obj_base_pts_feats.permute(1, 0, 2) if self.args.pred_diff_noise: bsz = dec_obj_base_pts_feats.size(0) t_item = t[0].item() alpha = self.var_sched.alphas[t_item] alpha_bar = self.var_sched.alpha_bars[t_item] sigma = self.var_sched.get_sigmas(t_item, 0.) 
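        # NOTE (illustrative sketch, not executed): q_posterior_mean_variance
        # (used above) evaluates the closed-form Gaussian posterior
        # q(x_{t-1} | x_t, x_0); its mean is a fixed convex combination of the
        # predicted x_0 and the current x_t. Assuming scalar schedule values
        # (names hypothetical):
        # def q_posterior_mean(x0, x_t, alpha, alpha_bar, alpha_bar_prev):
        #     beta = 1.0 - alpha
        #     coef1 = torch.sqrt(alpha_bar_prev) * beta / (1.0 - alpha_bar)
        #     coef2 = torch.sqrt(alpha) * (1.0 - alpha_bar_prev) / (1.0 - alpha_bar)
        #     return coef1 * x0 + coef2 * x_t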
# sigma # c0 = 1.0 / torch.sqrt(alpha) c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar) beta = self.var_sched.betas[[t[0].item()] * bsz] z = torch.randn_like(dec_obj_base_pts_feats) if t_item > 0 else torch.zeros_like(dec_obj_base_pts_feats) dec_obj_base_pts_feats = c0 * (pert_obj_base_pts_feats - c1 * dec_obj_base_pts_feats) + sigma * z # theta dec_obj_base_pts_feats = dec_obj_base_pts_feats.permute(1, 0, 2) real_dec_basejtsrel = model.model.decode_realbasejtsrel_from_objbasefeats(dec_obj_base_pts_feats, x) real_basejtsrel_seq_rt_dict.update( { 'real_dec_basejtsrel': real_dec_basejtsrel, 'dec_obj_base_pts_feats': dec_obj_base_pts_feats, } ) else: real_basejtsrel_seq_rt_dict = {} # else: # x_{t-1} # real_dec_basejtsrel = self._predict_xstart_from_eps(pert_rel_base_pts_to_rhand_joints, t=t, eps=real_dec_basejtsrel) if self.diff_basejtsrel: pert_rhand_params = x['pert_rhand_params'] # rel base pts outputs # params_output = out_dict['params_output'] print(f"pert_rhand_params: {pert_rhand_params.size()}, params_output: {params_output.size()}") if self.args.pred_diff_noise: ## eps -> estimated-noises # predict noise? # # b if self.args.use_var_sched: bsz = pert_rhand_params.size(0) t_item = t[0].item() alpha = self.var_sched.alphas[t_item] alpha_bar = self.var_sched.alpha_bars[t_item] sigma = self.var_sched.get_sigmas(t_item, 0.) # sigma # c0 = 1.0 / torch.sqrt(alpha) c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar) # x_t = traj[t] # # beta = self.var_sched.betas[[t[0].item()] * bsz] # if mask is not None: # x_t = x_t * mask # e_theta = self.net(x_t, beta=beta, context=context) z = torch.randn_like(params_output) if t_item > 0 else torch.zeros_like(params_output) params_output = c0 * (pert_rhand_params - c1 * params_output) + sigma * z # theta else: params_output = self._predict_xstart_from_eps(pert_rhand_params, t=t, eps=params_output) # base_pts_trans, base_normals_trans # # avg_transl, std_transl; avg_rot, std_rot ; avg_theta, std_theta # if self.args.phy_guided_sampling and t[0].item() < 10: # base_pts, base_normals # # projed_params: bsz x nf x params_dim # projed_params = self.phy_project_pred_verts(params_output, x['rhand_betas'], x['base_pts_trans'], x['base_normals_trans'], x, nn_iters=0 if t[0].item() == 0 else 0, wccd=True if t[0].item() == 0 else False) coef = 0.0 params_output = params_output * coef + projed_params * (1. - coef) ### # if self.args.phy_guided_sampling and t[0].item() < 200: # # phy_guided_sampling # # # pred_dec_jts = self._predict_xstart_from_eps(jts_fr_basepts, t, dec_jts_fr_basepts) # bsz x nf x nn_jts x nn_base_pts x 3 # pred_dec_jts = basejtsrel_output.clone() # # if self.args.sel_basepts_idx >= 0: # # pred_dec_jts = pred_dec_jts[:, :, :, self.args.sel_basepts_idx] # bsz x nf x nn_jts x 3 # # # else: # # pred_dec_jts = pred_dec_jts.mean(dim=-2) # bsz x nf x nn_jts x 3 # # # phy_projct_pred_joints(self, pred_joints, base_pts, base_normals) # joints_proj_dir = self.phy_projct_pred_joints(pred_dec_jts, x['normed_base_pts'], x['base_normals']) # bsz x nf x nn_jts x 3 # x_start_projed = pred_dec_jts - joints_proj_dir # # x_start_splat = x_start_projed.unsqueeze(-2) - x['normed_base_pts'].unsqueeze(1).unsqueeze(1) # bsz x nf x nn_jts x nn_base_pts x 3 ## # # dec_jts_fr_basepts_projed = self._predict_eps_from_xstart(jts_fr_basepts, t, x_start_splat) # bsz x nf x nn_jts x nn_base_pts x 3 # # dec_ratio = 0. # basejtsrel_output = basejtsrel_output * dec_ratio + x_start_projed * (1. 
- dec_ratio) ## use the predicted latents and pert_latents for the seq latents prediction ## # basejtsrel_output_mean, _, _ = self.q_posterior_mean_variance( # q posterior # x_start=basejtsrel_output, x_t=pert_rel_base_pts_outputs, t=t # ) # ## from model_variance to basejtsrel_seq_latents ### # basejtsrel_output_variance = _extract_into_tensor(model_variance, t, basejtsrel_output_mean.shape) # basejtsrel_output_log_variance = _extract_into_tensor(model_log_variance, t, basejtsrel_output_mean.shape) # basejtsrel_output = out_dict["basejtsrel_output"] print(f"params_output: {params_output.size()}") basejtsrel_seq_rt_dict = { ### basejtsrel seq latents ### # "avg_jts_outputs": avg_jts_outputs, # "basejtsrel_output_variance": basejtsrel_output_variance, # "basejtsrel_output_log_variance": basejtsrel_output_log_variance, # # "avg_jts_outputs_variance": avg_jts_outputs_variance, # "avg_jts_outputs_log_variance": avg_jts_outputs_log_variance, "params_output": params_output, } else: basejtsrel_seq_rt_dict = {} if self.diff_basejtse: dec_e_along_normals = out_dict['dec_e_along_normals'] dec_e_vt_normals = out_dict['dec_e_vt_normals'] pert_e_along_normals = x['pert_e_disp_rel_to_base_along_normals'] pert_e_vt_normals = x['pert_e_disp_rel_to_base_vt_normals'] # pert_rel_base_pts_outputs = x['pert_joints_offset_sequence'] # rel base pts outputs # # basejtsrel_output = out_dict['joints_offset_output'] if self.args.pred_diff_noise: ## eps -> estimated-noises # predict noise? # # b if self.args.use_var_sched: bsz = dec_e_along_normals.size(0) t_item = t[0].item() alpha = self.var_sched.alphas[t_item] alpha_bar = self.var_sched.alpha_bars[t_item] sigma = self.var_sched.get_sigmas(t_item, 0.) # sigma # c0 = 1.0 / torch.sqrt(alpha) c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar) beta = self.var_sched.betas[[t[0].item()] * bsz] z = torch.randn_like(dec_e_along_normals) if t_item > 0 else torch.zeros_like(dec_e_along_normals) dec_e_along_normals = c0 * (pert_e_along_normals - c1 * dec_e_along_normals) + sigma * z # theta z_vt_normals = torch.randn_like(dec_e_vt_normals) if t_item > 0 else torch.zeros_like(dec_e_vt_normals) dec_e_vt_normals = c0 * (pert_e_vt_normals - c1 * dec_e_vt_normals) + sigma * z_vt_normals # theta else: dec_e_along_normals = self._predict_xstart_from_eps(pert_e_along_normals, t=t, eps=dec_e_along_normals) dec_e_vt_normals = self._predict_xstart_from_eps(pert_e_vt_normals, t=t, eps=dec_e_vt_normals) # base_jts_e_feats = x['base_jts_e_feats'] ### x_t values here ### # pred_basejtse_seq_latents = out_dict['base_jts_e_feats'] # ### q-sampled latent mean here ### # basejtse_seq_latents_mean, _, _ = self.q_posterior_mean_variance( # x_start=pred_basejtse_seq_latents.permute(1, 0, 2), x_t=base_jts_e_feats.permute(1, 0, 2), t=t # ) # basejtse_seq_latents_mean = basejtse_seq_latents_mean.permute(1, 0, 2) # basejtse_seq_latents_variance = _extract_into_tensor(model_variance, t, basejtse_seq_latents_mean.permute(1, 0, 2).shape).permute(1, 0, 2) # basejtse_seq_latents_log_variance = _extract_into_tensor(model_log_variance, t, basejtse_seq_latents_mean.permute(1, 0, 2).shape).permute(1, 0, 2) # # base_jts_e_feats = out_dict["base_jts_e_feats"] # dec_e_along_normals = out_dict["dec_e_along_normals"] # dec_e_vt_normals = out_dict["dec_e_vt_normals"] basejtse_seq_rt_dict = { ### baesjtse seq latents ### # "basejtse_seq_latents_mean": basejtse_seq_latents_mean, # "basejtse_seq_latents_variance": basejtse_seq_latents_variance, # "basejtse_seq_latents_log_variance": basejtse_seq_latents_log_variance, 
"dec_e_along_normals": dec_e_along_normals, "dec_e_vt_normals": dec_e_vt_normals, } else: basejtse_seq_rt_dict = {} rt_dict.update(jts_seq_rt_dict) rt_dict.update(basejtsrel_seq_rt_dict) rt_dict.update(basejtse_seq_rt_dict) rt_dict.update(real_basejtsrel_seq_rt_dict) rt_dict.update(realbasejtsrel_to_joints_rt_dict) return rt_dict def _predict_xstart_from_eps(self, x_t, t, eps): assert x_t.shape == eps.shape return ( # extract into tensor # _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps ) def _predict_xstart_from_xprev(self, x_t, t, xprev): assert x_t.shape == xprev.shape return ( # (xprev - coef2*x_t) / coef1 _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev - _extract_into_tensor( self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape ) * x_t ) def _predict_eps_from_xstart(self, x_t, t, pred_xstart): return ( _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) def _scale_timesteps(self, t): if self.rescale_timesteps: return t.float() * (1000.0 / self.num_timesteps) return t def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): """ Compute the mean for the previous step, given a function cond_fn that computes the gradient of a conditional log probability with respect to x. In particular, cond_fn computes grad(log(p(y|x))), and we want to condition on y. This uses the conditioning strategy from Sohl-Dickstein et al. (2015). # """ gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs) new_mean = ( p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() ) return new_mean def condition_mean_with_grad(self, cond_fn, p_mean_var, x, t, model_kwargs=None): """ Compute the mean for the previous step, given a function cond_fn that computes the gradient of a conditional log probability with respect to x. In particular, cond_fn computes grad(log(p(y|x))), and we want to condition on y. This uses the conditioning strategy from Sohl-Dickstein et al. (2015). """ gradient = cond_fn(x, t, p_mean_var, **model_kwargs) new_mean = ( p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() ) return new_mean def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): """ Compute what the p_mean_variance output would have been, should the model's score function be conditioned by cond_fn. See condition_mean() for details on cond_fn. Unlike condition_mean(), this instead uses the conditioning strategy from Song et al (2020). """ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) eps = eps - (1 - alpha_bar).sqrt() * cond_fn( x, self._scale_timesteps(t), **model_kwargs ) out = p_mean_var.copy() out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) out["mean"], _, _ = self.q_posterior_mean_variance( x_start=out["pred_xstart"], x_t=x, t=t ) return out def condition_score_with_grad(self, cond_fn, p_mean_var, x, t, model_kwargs=None): """ Compute what the p_mean_variance output would have been, should the model's score function be conditioned by cond_fn. See condition_mean() for details on cond_fn. Unlike condition_mean(), this instead uses the conditioning strategy from Song et al (2020). 
""" alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) eps = eps - (1 - alpha_bar).sqrt() * cond_fn( x, t, p_mean_var, **model_kwargs ) out = p_mean_var.copy() out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) out["mean"], _, _ = self.q_posterior_mean_variance( x_start=out["pred_xstart"], x_t=x, t=t ) return out def judge_activated(self, target_setting): if target_setting: return 1 else: return 0 def p_sample( ## p sample ## self, model, x, t, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, const_noise=False, ): """ # p_sample for the p_ample # Sample x_{t-1} from the model at the given timestep. :param model: the model to sample from. # gaussian diffusion # :param x: the current tensor at x_{t-1}. :param t: the value of t, starting at 0 for the first diffusion step. :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. :param denoised_fn: if not None, a function which applies to the x_start prediction before it is used to sample. :param cond_fn: if not None, this is a gradient function that acts similarly to the model. :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :return: a dict containing the following keys: - 'sample': a random sample from the model. - 'pred_xstart': a prediction of x_0. """ multi_activated = ( self.judge_activated(self.diff_jts) + self.judge_activated(self.args.diff_realbasejtsrel_to_joints) + self.judge_activated(self.diff_realbasejtsrel) + self.judge_activated(self.diff_basejtsrel) + self.judge_activated(self.diff_basejtse) ) > 1.5 if multi_activated: # print(f"Multiple settings activated! Using combined sampling...") p_mena_variance_fn = self.p_mean_variance_cond # p_mean else: # print(f"Single setting activated! Using single sampling...") p_mena_variance_fn = self.p_mean_variance out = p_mena_variance_fn( model, x, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) rt_dict = {} if self.diff_jts: # bsz x ws x nnj x nnb x 3 # joints_seq_latents_noise = th.randn_like(x['joints_seq_latents']) # print('const_noise', const_noise) # seq x bsz x latent_dim # if const_noise: print(f"joints latents hape, ", x['joints_seq_latents'].shape) joints_seq_latents_noise = joints_seq_latents_noise[[0]].repeat(x['joints_seq_latents'].shape[0], 1, 1) # joints_seq_latents_nonzero_mask = ( (t != 0).float().view(-1, *([1] * (len(x['joints_seq_latents'].shape) - 1))) ) # no noise when t == 0 # bsz x nseq x dim # #### ==== joints_seq_latents ===== #### # t -> seq for const nosie .... # cnanot dpeict the laten tspace very well... 
# joints_seq_latents_sample = out["joints_seq_latents_mean"].permute(1, 0, 2) + joints_seq_latents_nonzero_mask * th.exp(0.5 * out["joints_seq_latents_log_variance"].permute(1, 0, 2)) * joints_seq_latents_noise.permute(1, 0, 2) # nseq x bsz x dim # joints_seq_latents_sample = joints_seq_latents_sample.permute(1, 0, 2) # #### ==== joints_seq_latents ===== #### joint_seq_output = out["joint_seq_output"] jts_seq_rt_dict = { "joints_seq_latents_sample": joints_seq_latents_sample, "joint_seq_output": joint_seq_output, } else: jts_seq_rt_dict = {} if self.args.diff_realbasejtsrel_to_joints: ## args.pred to joints # dec_joints_offset_output = realbasejtsrel_to_joints_rt_dict = { 'dec_joints_offset_output': out['dec_joints_offset_output'] } else: realbasejtsrel_to_joints_rt_dict = {} if self.diff_realbasejtsrel: if self.args.train_enc or ( self.args.pred_diff_noise and self.args.use_var_sched): real_dec_basejtsrel = out['real_dec_basejtsrel'] else: real_dec_basejtsrel_noise = th.randn_like(out['real_dec_basejtsrel']) if const_noise: real_dec_basejtsrel_noise = real_dec_basejtsrel_noise[[0]].repeat(out['real_dec_basejtsrel'].shape[0], 1, 1, 1, 1) real_dec_basejtsrel_nonzero_mask = ( (t != 0).float().view(-1, *([1] * (len(out['real_dec_basejtsrel'].shape) - 1))) ) real_dec_basejtsrel = out["real_dec_basejtsrel_mean"] + real_dec_basejtsrel_nonzero_mask * th.exp(0.5 * out["real_dec_basejtsrel_log_variance"]) * real_dec_basejtsrel_noise real_basejtsrel_rt_dict = { 'real_dec_basejtsrel': real_dec_basejtsrel, } if self.args.train_enc: real_basejtsrel_rt_dict['dec_obj_base_pts_feats'] = out['dec_obj_base_pts_feats'] else: real_basejtsrel_rt_dict = {} if self.diff_basejtsrel: if self.args.pred_diff_noise and self.args.use_var_sched: params_output = out['params_output'] else: ##### ===== Sample for basejtsrel_seq_latents_sample ===== ##### ### rel_base_pts_outputs mask ### basejtsrel_seq_latents_noise = th.randn_like(out['basejtsrel_output']) if const_noise: ## seq latents noise ## basejtsrel_seq_latents_noise = basejtsrel_seq_latents_noise[[0]].repeat(out['basejtsrel_output'].shape[0], 1, 1, 1, 1) basejtsrel_seq_latents_nonzero_mask = ( (t != 0).float().view(-1, *([1] * (len(out['basejtsrel_output'].shape) - 1))) ) # no noise when t == 0 #### ==== basejtsrel_seq_latents ===== #### ## sample latent codes -> denoise latent codes basejtsrel_seq_latents_sample = out["basejtsrel_output"] + basejtsrel_seq_latents_nonzero_mask * th.exp(0.5 * out["basejtsrel_output_log_variance"]) * basejtsrel_seq_latents_noise # basejtsrel_seq_latents_sample = basejtsrel_seq_latents_sample.permute(1, 0, 2) #### ==== basejtsrel_seq_latents ===== #### ##### ===== Sample for basejtsrel_seq_latents_sample ===== ##### basejtsrel_rt_dict = { "params_output": params_output, # "avg_jts_outputs_sample": avg_jts_outputs_sample, } else: basejtsrel_rt_dict = {} if self.diff_basejtse: # ##### ===== Sample for basejtse_seq_latents_sample ===== ##### # ### rel_base_pts_outputs mask ### # basejtse_seq_latents_noise = th.randn_like(x['base_jts_e_feats']) # # print('const_noise', const_noise) # if const_noise: # basejtse_seq_latents_noise = basejtse_seq_latents_noise[[0]].repeat(x['base_jts_e_feats'].shape[0], 1, 1, 1, 1) # basejtse_seq_latents_nonzero_mask = ( # (t != 0).float().view(-1, *([1] * (len(x['base_jts_e_feats'].shape) - 1))) # ) # no noise when t == 0 # #### ==== basejtsrel_seq_latents ===== #### # basejtse_seq_latents_sample = out["basejtse_seq_latents_mean"].permute(1, 0, 2) + basejtse_seq_latents_nonzero_mask * th.exp(0.5 * 
out["basejtse_seq_latents_log_variance"].permute(1, 0, 2)) * basejtse_seq_latents_noise.permute(1, 0, 2) # basejtse_seq_latents_sample = basejtse_seq_latents_sample.permute(1, 0, 2) # #### ==== basejtsrel_seq_latents ===== #### # ##### ===== Sample for basejtse_seq_latents_sample ===== ##### dec_e_along_normals = out["dec_e_along_normals"] ## dec_e_vt_normals = out["dec_e_vt_normals"] basejtse_rt_dict = { # "basejtse_seq_latents_sample": basejtse_seq_latents_sample, "dec_e_along_normals": dec_e_along_normals, "dec_e_vt_normals": dec_e_vt_normals, } else: basejtse_rt_dict = {} rt_dict.update(jts_seq_rt_dict) rt_dict.update(basejtsrel_rt_dict) rt_dict.update(basejtse_rt_dict) rt_dict.update(real_basejtsrel_rt_dict) rt_dict.update(realbasejtsrel_to_joints_rt_dict) return rt_dict def p_sample_with_grad( self, model, x, t, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, ): """ Sample x_{t-1} from the model at the given timestep. :param model: the model to sample from. :param x: the current tensor at x_{t-1}. :param t: the value of t, starting at 0 for the first diffusion step. :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. :param denoised_fn: if not None, a function which applies to the x_start prediction before it is used to sample. :param cond_fn: if not None, this is a gradient function that acts similarly to the model. :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :return: a dict containing the following keys: - 'sample': a random sample from the model. - 'pred_xstart': a prediction of x_0. """ with th.enable_grad(): x = x.detach().requires_grad_() out = self.p_mean_variance( model, x, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) noise = th.randn_like(x) nonzero_mask = ( (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) ) # no noise when t == 0 if cond_fn is not None: out["mean"] = self.condition_mean_with_grad( cond_fn, out, x, t, model_kwargs=model_kwargs ) sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise return {"sample": sample, "pred_xstart": out["pred_xstart"].detach()} def p_sample_loop( self, model, shape, noise=None, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, progress=False, skip_timesteps=0, init_image=None, randomize_class=False, cond_fn_with_grad=False, dump_steps=None, const_noise=False, st_timestep=None, ): ## """ Generate samples from the model. :param model: the model module. :param shape: the shape of the samples, (N, C, H, W). :param noise: if specified, the noise from the encoder to sample. Should be of the same shape as `shape`. :param clip_denoised: if True, clip x_start predictions to [-1, 1]. :param denoised_fn: if not None, a function which applies to the x_start prediction before it is used to sample. :param cond_fn: if not None, this is a gradient function that acts similarly to the model. :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :param device: if specified, the device to create the samples on. If not specified, use a model parameter's device. :param progress: if True, show a tqdm progress bar. :param const_noise: If True, will noise all samples with the same noise throughout sampling :return: a non-differentiable batch of samples. 
""" final = None # if dump_steps is not None: ## dump steps is not None ## dump = [] # function, yield, enumerate! -> for i, sample in enumerate(self.p_sample_loop_progressive( model, # p_sample # shape, # p_sample # noise=noise, clip_denoised=clip_denoised, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, device=device, progress=progress, skip_timesteps=skip_timesteps, init_image=init_image, randomize_class=randomize_class, cond_fn_with_grad=cond_fn_with_grad, const_noise=const_noise, # the same noise # st_timestep=st_timestep, )): if dump_steps is not None and i in dump_steps: dump.append(deepcopy(sample)) final = sample if dump_steps is not None: return dump return final # score # socre p_sample_loop_progressive # def p_sample_loop_progressive( self, model, shape, noise=None, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, progress=False, skip_timesteps=0, init_image=None, randomize_class=False, cond_fn_with_grad=False, const_noise=False, st_timestep=None, ): # """ # p_sample loop progressive # Generate samples from the model and yield intermediate samples from each timestep of diffusion. Arguments are the same as p_sample_loop(). Returns a generator over dicts, where each dict is the return value of p_sample(). """ ####### ==== a conditional ssampling from init_images here!!! ==== ####### ## === give joints shape here === ## ### ==== set the shape for sampling ==== ### ### === init image sshould not be none === ### base_pts = init_image['base_pts'] # base_pts, base_normals # bsz x nn_base_pts x 3 # base_normals = init_image['base_normals'] ## base normals ## base normals ## # rel_base_pts_to_rhand_joints = init_image['rel_base_pts_to_rhand_joints'] # dist_base_pts_to_rhand_joints = init_image['dist_base_pts_to_rhand_joints'] rhand_joints = init_image['rhand_joints'] # rhand_joints = init_image['gt_rhand_joints'] if self.args.use_anchors: # rhand_joints: bsz x nf x nn_anchors x 3 # rhand_joints = init_image['pert_rhand_anchors'] ## bsz x nf x nn_anchors x 3 -> for the anchors of the rhand # # rhand_transl, rhand_rot, rhand_theta # rhand_transl = init_image['rhand_transl'] # bsz x ws x 3 --> translational vectors # rhand_rot = init_image['rhand_rot'] # bsz x ws x 3 -> rot var # rhand_theta = init_image['rhand_theta'] # bsz x ws x 24 -> theta var # # base_pts_trans, base_normals_trans # # avg_transl, std_transl; avg_rot, std_rot ; avg_theta, std_theta # obj_rot, obj_transl # obj_rot = init_image['obj_rot'] obj_transl = init_image['obj_transl'] if self.args.use_arti_obj: base_pts_trans = torch.matmul(base_pts, obj_rot) + obj_transl.unsqueeze(2) # bsz x ws x nn_base_pts x 3 # base_normals_trans = torch.matmul(base_normals, obj_rot) else: base_pts_trans = torch.matmul(base_pts.unsqueeze(1), obj_rot) + obj_transl.unsqueeze(2) # bsz x ws x nn_base_pts x 3 # base_normals_trans = torch.matmul(base_normals.unsqueeze(1), obj_rot) # rhand_transl, rhand_rot, rhand_theta # # avg_transl, std_transl; avg_rot, std_rot; avg_theta, std_theta # ### avg, std ### avg_transl, std_transl = torch.mean(rhand_transl, dim=1, keepdim=True), torch.std(rhand_transl, dim=1, keepdim=True) avg_rot, std_rot = torch.mean(rhand_rot, dim=1, keepdim=True), torch.std(rhand_rot, dim=1, keepdim=True) avg_theta, std_theta = torch.mean(rhand_theta, dim=1, keepdim=True), torch.std(rhand_theta, dim=1, keepdim=True) rhand_transl = (rhand_transl - avg_transl) / std_transl # xxx rhand_rot = (rhand_rot - avg_rot) / std_rot # bsz x ws x 3 # rhand_theta= (rhand_theta - avg_theta) / 
std_theta ### bsz x ws x 24 rhand_params = torch.cat( [rhand_transl, rhand_rot, rhand_theta], dim=-1 # bsz x ws x (3 + 3 + 24) ) # rhand_joints = rhand_joints - ## vage for whether this model can work ### # avg_joints_sequence = std_joints_sequence = torch.std(rhand_joints, dim=1) avg_joints_sequence = torch.mean(rhand_joints, dim=1) # bsz x nnj x 3 ---> for rhand joints ## ## if self.args.joint_std_v2: std_joints_sequence = torch.std(rhand_joints.view(rhand_joints.size(0), -1, 3), dim=1).unsqueeze(1) avg_joints_sequence = torch.mean(rhand_joints.view(rhand_joints.size(0), -1, 3), dim=1).unsqueeze(1) elif self.args.joint_std_v3: avg_joints_sequence = torch.mean(rhand_joints.view(rhand_joints.size(0), -1, 3), dim=1).unsqueeze(1) # ws x 1 x 3 # std_joints_sequence = torch.std((rhand_joints - avg_joints_sequence.unsqueeze(1)).view(rhand_joints.size(0), -1), dim=1).unsqueeze(1).unsqueeze(1) joints_offset_sequence = rhand_joints - avg_joints_sequence.unsqueeze(1) # bsz x nf x nnj x 3 # # if self.args.jts_sclae_stra == "std": # and only use the latents # # joints_offset_sequence = joints_offset_sequence / std_joints_sequence if not self.args.use_arti_obj: normed_base_pts = base_pts - avg_joints_sequence # bsz x nnb x 3 else: # base_pts: bsz x nf x nnb x 3 # normed_base_pts = base_pts - avg_joints_sequence.unsqueeze(1) # bsz x nf x nnb x 3 # # joints_offset_sequence_ori = joints_offset_sequence.clone() # rhand_joints_ori = rhand_joints.clone() # jts scale stra # jts scale strategies ## # std_joints_sequence.unsqueeze(1), avg_joints_sequence.unsqueeze(1) if self.args.jts_sclae_stra == "std": joints_offset_sequence = joints_offset_sequence / std_joints_sequence.unsqueeze(1) if not self.args.use_arti_obj: normed_base_pts = normed_base_pts / std_joints_sequence else: normed_base_pts = normed_base_pts / std_joints_sequence.unsqueeze(1) else: std_joints_sequence = torch.ones_like(std_joints_sequence) # if 'sampled_base_pts_nearest_obj_pc' in init_image: # ambient_init_image = { # 'sampled_base_pts_nearest_obj_pc': init_image['sampled_base_pts_nearest_obj_pc'], # 'sampled_base_pts_nearest_obj_vns': init_image['sampled_base_pts_nearest_obj_vns'], # } ####### E ####### # # per_frame_avg_disp_along_normals, per_frame_std_disp_along_normals # # per_frame_avg_disp_vt_normals, per_frame_std_disp_vt_normals # if self.args.wo_e_normalization: # init_image['per_frame_avg_disp_along_normals'] = torch.zeros_like(init_image['per_frame_avg_disp_along_normals']) # init_image['per_frame_avg_disp_vt_normals'] = torch.zeros_like(init_image['per_frame_avg_disp_vt_normals']) # init_image['per_frame_std_disp_along_normals'] = torch.ones_like(init_image['per_frame_std_disp_along_normals']) # init_image['per_frame_std_disp_vt_normals'] = torch.ones_like(init_image['per_frame_std_disp_vt_normals']) ####### E ####### if self.args.wo_rel_normalization: # per_frame_avg_joints_rel, per_frame_std_joints_rel # init_image['per_frame_avg_joints_rel'] = torch.zeros_like(init_image['per_frame_avg_joints_rel']) init_image['per_frame_std_joints_rel'] = torch.ones_like(init_image['per_frame_std_joints_rel']) init_image_avg_std_stats = { 'rhand_joints': init_image['rhand_joints'], 'per_frame_avg_joints_rel': init_image['per_frame_avg_joints_rel'], 'per_frame_std_joints_rel': init_image['per_frame_std_joints_rel'], 'per_frame_avg_joints_dists_rel': init_image['per_frame_avg_joints_dists_rel'], 'per_frame_std_joints_dists_rel': init_image['per_frame_std_joints_dists_rel'], } if device is None: device = next(model.parameters()).device assert 
isinstance(shape, (tuple, list)) ### without e normalization ### # indicies indices = list(range(self.num_timesteps - skip_timesteps))[::-1] if st_timestep is not None: ## indices indices = indices[-st_timestep: ] print(f"st_timestep: {st_timestep}, indices: {indices}") # joints_scaling_factor = 5. # rel_base_pts_to_rhand_joints: bsz x ws x nnj x nnb x 3 ## # base_pts: bsz x nnb x 3 # # base_normals: bsz x nnb x 3 ## # rel_base_pts_to_rhand_joints = rhand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1) # bsz x nnj x nnb x 3 # rhand_joints ## # init_image['per_frame_avg_joints_rel'] = torch.mean(rel_base_pts_to_rhand_joints, dim=1, keepdim=True) # # x_start['per_frame_avg_joints_rel'] = torch # # bsz x ws x nnj x nnb x 3 # # rel_base_pts_to_rhand_joints = (rel_base_pts_to_rhand_joints - init_image['per_frame_avg_joints_rel']) / init_image['per_frame_std_joints_rel'] if not self.args.use_arti_obj: ### rel base pts to rhand joints #joints offset sequence # joints offset sequence # joints # joints_offset_sequence - normed_base_pts rel_base_pts_to_rhand_joints = joints_offset_sequence.unsqueeze(-2) - normed_base_pts.unsqueeze(1).unsqueeze(1) # bsz x nnf x nnj x nnb x 3 --> relative positions from baes pts to rhand joints # else: # rel_base_pts_to_rhand_joints = joints_offset_sequence.unsqueeze(-2) - normed_base_pts.unsqueeze(2) # bsz x nf x nn_joints x nn_base_pts x 3 # maxx_rel_base_pts_to_rhand_joints, _ = torch.max(rel_base_pts_to_rhand_joints.view(-1, 3), dim=0) minn_rel_base_pts_to_rhand_joints, _ = torch.min(rel_base_pts_to_rhand_joints.view(-1, 3), dim=0) print(f"maxx_rel_base_pts_to_rhand_joints: {maxx_rel_base_pts_to_rhand_joints}, minn_rel_base_pts_to_rhand_joints: {minn_rel_base_pts_to_rhand_joints}") if self.args.real_basejtsrel_norm_stra == "mean": rel_base_pts_to_rhand_joints = rhand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1) bsz, nf, nnj, nnb = rel_base_pts_to_rhand_joints.size()[:4] # exp_rel_base_pts_to_rhand_joints = rel_base_pts_to_rhand_joints.view(bsz, nf * nnj * nnb, -1).contiguous() # avg_rel_base_pts_to_rhand_joints = torch.mean(exp_rel_base_pts_to_rhand_joints, dim=1, keepdim=True) # bsz x 1 x 3 # # std_rel_base_pts_to_rhand_joints = torch.std(exp_rel_base_pts_to_rhand_joints, dim=1, keepdim=True) # bsz x 1 x 3 # #### rel_base_pts_to_rhand_joints # rel_base_pts_to_rhand_joints = (rel_base_pts_to_rhand_joints - avg_rel_base_pts_to_rhand_joints.unsqueeze(1).unsqueeze(1)) # / std_rel_base_pts_to_rhand_joints.unsqueeze(1).unsqueeze(1) rel_base_pts_to_rhand_joints = (rel_base_pts_to_rhand_joints - rel_base_pts_to_rhand_joints.mean(dim=0, keepdim=True)) elif self.args.real_basejtsrel_norm_stra == "std": # avg_rel_base_pts_to_rhand_joints, std_rel_base_pts_to_rhand_joints # rel_base_pts_to_rhand_joints = rhand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1) bsz, nf, nnj, nnb = rel_base_pts_to_rhand_joints.size()[:4] # rel exp_rel_base_pts_to_rhand_joints = rel_base_pts_to_rhand_joints.view(bsz, nf * nnj * nnb, -1).contiguous() avg_rel_base_pts_to_rhand_joints = torch.mean(exp_rel_base_pts_to_rhand_joints, dim=1, keepdim=True) # bsz x 1 x 3 rel_base_pts_to_rhand_joints = rel_base_pts_to_rhand_joints - avg_rel_base_pts_to_rhand_joints.unsqueeze(1).unsqueeze(1) std_rel_base_pts_to_rhand_joints = torch.std(rel_base_pts_to_rhand_joints.view(bsz, -1), dim=-1, keepdim=True).unsqueeze(1) # bsz x 1 x 1 rel_base_pts_to_rhand_joints = rel_base_pts_to_rhand_joints / std_rel_base_pts_to_rhand_joints.unsqueeze(1).unsqueeze(1) if self.denoising_stra == 
"rep": ''' Normalization Strategy 4 ''' my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0] # t con # normed_rhand_joints = (rhand_joints - self.avg_jts.unsqueeze(0).to(rhand_joints.device)) / self.std_jts.unsqueeze(0).to(rhand_joints.device) # noise_rhand_joints = th.randn_like(normed_rhand_joints) # pert_normed_rhand_joints = self.q_sample(normed_rhand_joints, my_t, noise=noise_rhand_joints) # rhand_joints = (rhand_joints - avg_exp_rhand_joints.unsqueeze(1) ) / extents_rhand_joints.unsqueeze(1) # scaled_rhand_joints = rhand_joints * joints_scaling_factor # noise_scaled_rhand_joints = th.randn_like(scaled_rhand_joints) # pert_scaled_rhand_joints = self.q_sample(scaled_rhand_joints, my_t, noise=noise_scaled_rhand_joints) # pert_rhand_joints: bsz x nnj x 3 ## -> # pert_rhand_joints = pert_normed_rhand_joints * self.std_jts.unsqueeze(0).to(rhand_joints.device) + self.avg_jts.unsqueeze(0).to(rhand_joints.device) ####### E ####### # if not self.args.use_arti_obj: # # rel_base_pts_to_rhand_joints: bsz x nf x nnj x nnb x 3 # # denormed_rel_base_pts_to_rhand_joints = rhand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1) # bsz x nf x nnj x nnb x 3 # else: # denormed_rel_base_pts_to_rhand_joints = rhand_joints.unsqueeze(-2) - base_pts.unsqueeze(2) # # ### Calculate moving related energies ### # # denormed_rhand_joints: bsz x nf x nnj x 3 -> denormed rhand joints here # # # pert rhand joints # ## denormed base pts to rhand joints -> denormed rel positions ## # # denormed_rel_base_pts_to_rhand_joints = rel_base_pts_to_rhand_joints * init_image['per_frame_std_joints_rel'] + init_image['per_frame_avg_joints_rel'] # # denormed_rhand_joints = rhand_joints * extents_rhand_joints.unsqueeze(1) + avg_exp_rhand_joints.unsqueeze(1) # # denormed_dist_rel_base_pts_to_pert_rhand_joints = torch.sum( # # denormed_rel_base_pts_to_rhand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1 # # ) ## l2 real base pts # k_f = 1. 
## l2 rel base pts to pert rhand joints ## # # l2_rel_base_pts_to_pert_rhand_joints: bsz x nf x nnj x nnb # # l2_rel_base_pts_to_pert_rhand_joints = torch.norm(denormed_rel_base_pts_to_rhand_joints, dim=-1) # ### att_forces ## # att_forces = torch.exp(-k_f * l2_rel_base_pts_to_pert_rhand_joints) # bsz x nf x nnj x nnb # # # bsz x (ws - 1) x nnj x nnb # # att_forces = att_forces[:, :-1, :, :] # attraction forces -1 # # # rhand_joints: ws x nnj x 3 # -> (ws - 1) x nnj x 3 ## rhand_joints ## # # bsz x (ws - 1) x nnj x 3 --> displacements s# # # denormed_rhand_joints_disp = denormed_rhand_joints[:, 1:, :, :] - denormed_rhand_joints[:, :-1, :, :] # denormed_rhand_joints_disp = rhand_joints[:, 1:, :, :] - rhand_joints[:, :-1, :, :] # # distance -- base_normalss,; (ws - 1) x nnj x nnb x 3 --> bsz x (ws - 1) x nnj x nnb # # # signed_dist_base_pts_to_pert_rhand_joints_along_normal # bsz x (ws - 1) x nnj x nnb # # if not self.args.use_arti_obj: # signed_dist_base_pts_to_rhand_joints_along_normal = torch.sum( # base_normals.unsqueeze(1).unsqueeze(1) * denormed_rhand_joints_disp.unsqueeze(-2), dim=-1 # ) # # rel_base_pts_to_pert_rhand_joints_vt_normal: bsz x (ws -1) x nnj x nnb x 3 -> the relative positions vertical to base normals # # rel_base_pts_to_rhand_joints_vt_normal = denormed_rhand_joints_disp.unsqueeze(-2) - signed_dist_base_pts_to_rhand_joints_along_normal.unsqueeze(-1) * base_normals.unsqueeze(1).unsqueeze(1) # else: # signed_dist_base_pts_to_rhand_joints_along_normal = torch.sum( # base_normals[:, :-1].unsqueeze(2) * denormed_rhand_joints_disp.unsqueeze(-2), dim=-1 # ) # # rel_base_pts_to_pert_rhand_joints_vt_normal: bsz x (ws -1) x nnj x nnb x 3 -> the relative positions vertical to base normals # # rel_base_pts_to_rhand_joints_vt_normal = denormed_rhand_joints_disp.unsqueeze(-2) - signed_dist_base_pts_to_rhand_joints_along_normal.unsqueeze(-1) * base_normals[:, :-1].unsqueeze(2) # bsz x nf x nn_joints x nn_base_pts x 3 --> rel # dist_base_pts_to_rhand_joints_vt_normal = torch.sqrt(torch.sum( # rel_base_pts_to_rhand_joints_vt_normal ** 2, dim=-1 # )) # k_a = 1. # k_b = 1. 
# ### # e_disp_rel_to_base_along_normals = k_a * att_forces * torch.abs(signed_dist_base_pts_to_rhand_joints_along_normal) # # (ws - 1) x nnj x nnb # -> dist vt normals # ## # e_disp_rel_to_baes_vt_normals = k_b * att_forces * dist_base_pts_to_rhand_joints_vt_normal ####### E ####### # nf x nnj x nnb ---> dist_vt_normals -> nf x nnj x nnb # # torch.sqrt() ## # dist_base_pts_to_pert_rhand_joints, rel_base_pts_to_pert_rhand_joints # # e_disp_rel_to_base_along_normals, e_disp_rel_to_baes_vt_normals # # # print("per frmae avg jalong normals,", x_start['per_frame_avg_disp_along_normals'].size(), "per frame std along normals", x_start['per_frame_std_disp_along_normals'].size(), "e_disp_rel_to_base_along_normals", e_disp_rel_to_base_along_normals.size(), "e_disp_rel_to_baes_vt_normals", e_disp_rel_to_baes_vt_normals.size()) ### e_disp_rel_to_base_along_normals ### ---> # e_disp_rel_to_base_along_normals = (e_disp_rel_to_base_along_normals - init_image['per_frame_avg_disp_along_normals']) / init_image['per_frame_std_disp_along_normals'] # e_disp_rel_to_baes_vt_normals = (e_disp_rel_to_baes_vt_normals - init_image['per_frame_avg_disp_vt_normals']) / init_image['per_frame_std_disp_vt_normals'] ''' Normalization Strategy 4 ''' else: raise ValueError(f"Unrecognized denoising stra: {self.denoising_stra}") # my_t = th.tensor([indices[-1]] * shape[0], device=device) my_t = th.tensor([indices[0]] * shape[0], device=device) # clean_joint_seq_latents = model(input_data, self._scale_timesteps(my_t)) # noise_joint_seq_latents = th.randn_like(clean_joint_seq_latents) # # pert_joint_seq_latents: bsz x seq x d # # pert_joint_seq_latents = self.q_sample(clean_joint_seq_latents.permute(1, 0, 2).contiguous(), my_t, noise=noise_joint_seq_latents.permute(1, 0, 2).contiguous()).permute(1, 0, 2).contiguous() noise_rhand_params = th.randn_like(rhand_params) pert_rhand_params = self.q_sample( # bsz x ws x (3 + 3 + 24) rhand_params, my_t, noise_rhand_params ) ####### E ####### # if not self.args.wo_e_normalization and self.args.e_normalization_stra == "cent": # bsz = e_disp_rel_to_base_along_normals.size(0) # nf, nnj, nnb = e_disp_rel_to_base_along_normals.size()[1:] # high dimensional # # the max value and min value of all values # #bs z x nnf x nnj x nnb --> for the along normals values and vt normals values ## # maxx_e_disp_rel_to_base_along_normals, _ = torch.max(e_disp_rel_to_base_along_normals.view(bsz, -1), dim=-1) # minn_e_disp_rel_to_base_along_normals , _ = torch.min(e_disp_rel_to_base_along_normals.view(bsz, -1), dim=-1) # maxx_e_disp_rel_to_base_vt_normals , _ = torch.max(e_disp_rel_to_baes_vt_normals.view(bsz, -1), dim=-1) # minn_e_disp_rel_to_base_vt_normals , _ = torch.min(e_disp_rel_to_baes_vt_normals.view(bsz, -1), dim=-1) # maxx_e_disp_rel_to_base_along_normals = maxx_e_disp_rel_to_base_along_normals.unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, 1, nnj, nnb) # minn_e_disp_rel_to_base_along_normals = minn_e_disp_rel_to_base_along_normals.unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, 1, nnj, nnb) # maxx_e_disp_rel_to_base_vt_normals = maxx_e_disp_rel_to_base_vt_normals.unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, 1, nnj, nnb) # minn_e_disp_rel_to_base_vt_normals = minn_e_disp_rel_to_base_vt_normals.unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, 1, nnj, nnb) # init_image['per_frame_avg_disp_along_normals'] = (maxx_e_disp_rel_to_base_along_normals + minn_e_disp_rel_to_base_along_normals) / 2. 
# init_image['per_frame_avg_disp_vt_normals'] = (maxx_e_disp_rel_to_base_vt_normals + minn_e_disp_rel_to_base_vt_normals) / 2.
# init_image['per_frame_std_disp_along_normals'] = torch.ones_like(init_image['per_frame_std_disp_along_normals'])
# init_image['per_frame_std_disp_vt_normals'] = torch.ones_like(init_image['per_frame_std_disp_vt_normals'])
# # normalize: base along normals #
# e_disp_rel_to_base_along_normals = (e_disp_rel_to_base_along_normals - init_image['per_frame_avg_disp_along_normals']) / init_image['per_frame_std_disp_along_normals']
# e_disp_rel_to_baes_vt_normals = (e_disp_rel_to_baes_vt_normals - init_image['per_frame_avg_disp_vt_normals']) / init_image['per_frame_std_disp_vt_normals']
####### E #######
#### Design notes: adding noise on joints (add_noise_onjts / add_noise_onjts_single) ####
# Rigid objects move globally; how about skipping those canonicalizations and only
# modeling correct contacts? Attraction forces vs. distances:
#   k_f = e^{-k * ||v_o - v_h||} --> the proximity value between each object/hand point pair;
#   one possible denoising target is the distance from a hand joint to the object surface.
# Manipulating the object amounts to adding forces to it. A simple case: map joint points
# to object points and denoise relative positions (relative positions; joint trajectory).
# A value negatively proportional to distance, e.g. exp(-||x_o - x_h||_2), describes the
# moving consistency between the hand and the object -- a contact map, or a generalized
# contact map.
### Add noise to rel_base_pts_to_rhand_joints ###
noise_rel_base_pts_to_rhand_joints = th.randn_like(rel_base_pts_to_rhand_joints)  # bsz x ws x nnjts x 3 -> for each joint point #
pert_rel_base_pts_to_rhand_joints = self.q_sample(  # bsz x ws x nnjts x 3 --> perturbed in the spatial space #
    rel_base_pts_to_rhand_joints, my_t, noise_rel_base_pts_to_rhand_joints
)
if self.args.add_noise_onjts:  ### add noise on joints, then recompute the relative representation ###
    if self.args.use_same_noise_for_rep:  ### use the same noise for the rep ###
        noise_joints_offset_output = torch.randn_like(joints_offset_sequence)
        pert_joints_offset_output = self.q_sample(
            joints_offset_sequence, my_t, noise_joints_offset_output
        )
        if not self.args.use_arti_obj:
            pert_rel_base_pts_to_rhand_joints = pert_joints_offset_output.unsqueeze(-2) - normed_base_pts.unsqueeze(1).unsqueeze(1)
            noise_rel_base_pts_to_rhand_joints = noise_joints_offset_output.unsqueeze(-2).repeat(1, 1, 1, normed_base_pts.size(1), 1)
        else:
            pert_rel_base_pts_to_rhand_joints = pert_joints_offset_output.unsqueeze(-2) - normed_base_pts.unsqueeze(2)
            noise_rel_base_pts_to_rhand_joints = noise_joints_offset_output.unsqueeze(-2).repeat(1, 1, 1, normed_base_pts.size(2), 1)
    else:
        if not self.args.use_arti_obj:
            joints_offset_output_exp = joints_offset_sequence.unsqueeze(-2).repeat(1, 1, 1, normed_base_pts.size(1), 1)
            noise_rel_base_pts_to_rhand_joints = torch.randn_like(joints_offset_output_exp)  # normed_base_pts
            pert_joints_offset_output_exp = self.q_sample(
                joints_offset_output_exp, my_t, noise_rel_base_pts_to_rhand_joints
            )
            pert_rel_base_pts_to_rhand_joints =
pert_joints_offset_output_exp - normed_base_pts.unsqueeze(1).unsqueeze(1) else: joints_offset_output_exp = joints_offset_sequence.unsqueeze(-2).repeat(1, 1, 1, normed_base_pts.size(2), 1) noise_rel_base_pts_to_rhand_joints = torch.randn_like(joints_offset_output_exp) # normed_base_pts pert_joints_offset_output_exp = self.q_sample( ## bsz x ws x nnjts x 3 --> pert in the spatial space ### joints_offset_output_exp, my_t, noise_rel_base_pts_to_rhand_joints ) pert_rel_base_pts_to_rhand_joints = pert_joints_offset_output_exp - normed_base_pts.unsqueeze(2) elif self.args.add_noise_onjts_single: # joints offset sequence # joints offset single # noise_joints_offset_output = torch.randn_like(joints_offset_sequence) pert_joints_offset_output = self.q_sample( ## bsz x ws x nnjts x 3 --> pert in the spatial space ### joints_offset_sequence, my_t, noise_joints_offset_output ) if not self.args.use_arti_obj: pert_rel_base_pts_to_rhand_joints = pert_joints_offset_output.unsqueeze(-2) - normed_base_pts.unsqueeze(1).unsqueeze(1) noise_rel_base_pts_to_rhand_joints = noise_joints_offset_output.unsqueeze(-2).repeat(1, 1, 1, normed_base_pts.size(1), 1) else: pert_rel_base_pts_to_rhand_joints = pert_joints_offset_output.unsqueeze(-2) - normed_base_pts.unsqueeze(2) noise_rel_base_pts_to_rhand_joints = noise_joints_offset_output.unsqueeze(-2).repeat(1, 1, 1, normed_base_pts.size(2), 1) if self.args.train_enc: pert_rel_base_pts_to_rhand_joints = rel_base_pts_to_rhand_joints # ### Add noise to rel_baes_pts_to_rhand_joints ### # noise_rel_base_pts_to_rhand_joints = th.randn_like(rel_base_pts_to_rhand_joints) ## bsz x ws x nnjts x 3 -> for each joint point # # pert_rel_base_pts_to_rhand_joints = self.q_sample( ## bsz x ws x nnjts x 3 --> pert in the spatial space ### # rel_base_pts_to_rhand_joints, my_t, noise_rel_base_pts_to_rhand_joints # ) # noise_avg_joints_sequence = th.randn_like(avg_joints_sequence) # pert_avg_joints_sequence = self.q_sample( # bsz x nnjts x 3 --> the avg joitns for each ... 
#     avg_joints_sequence, my_t, noise_avg_joints_sequence
# )
# joints_offset_sequence: perturb the per-frame joint offsets #
noise_joints_offset_sequence = th.randn_like(joints_offset_sequence)
print(f"my_t: {my_t}")
pert_joints_offset_sequence = self.q_sample(
    joints_offset_sequence, my_t, noise_joints_offset_sequence
)
if self.args.add_noise_onjts_single:
    noise_joints_offset_sequence = noise_joints_offset_output
    pert_joints_offset_sequence = pert_joints_offset_output
if not self.args.use_arti_obj:
    if self.args.add_noise_onjts_single or (self.diff_realbasejtsrel and self.diff_basejtsrel):
        pert_rel_base_pts_to_rhand_joints = pert_joints_offset_sequence.unsqueeze(-2) - normed_base_pts.unsqueeze(1).unsqueeze(1)
    if self.args.finetune_with_cond:
        pert_rel_base_pts_to_rhand_joints = pert_joints_offset_sequence.unsqueeze(-2) - normed_base_pts.unsqueeze(1).unsqueeze(1)
        print(f"pert_rel_base_pts_to_rhand_joints: {pert_rel_base_pts_to_rhand_joints.size()}")
else:
    if self.args.add_noise_onjts_single or (self.diff_realbasejtsrel and self.diff_basejtsrel):
        pert_rel_base_pts_to_rhand_joints = pert_joints_offset_sequence.unsqueeze(-2) - normed_base_pts.unsqueeze(2)
    if self.args.finetune_with_cond:
        pert_rel_base_pts_to_rhand_joints = pert_joints_offset_sequence.unsqueeze(-2) - normed_base_pts.unsqueeze(2)
        print(f"pert_rel_base_pts_to_rhand_joints: {pert_rel_base_pts_to_rhand_joints.size()}")
# Debug dump of the perturbed joint sequences (kept for reference):
# sv_pert_dict = {
#     'joints_offset_sequence': joints_offset_sequence.detach().cpu().numpy(),
#     'pert_joints_offset_sequence': pert_joints_offset_sequence.detach().cpu().numpy(),
#     'noise_joints_offset_sequence': noise_joints_offset_sequence.detach().cpu().numpy(),
#     'joints_offset_sequence_ori': joints_offset_sequence_ori.detach().cpu().numpy(),
#     'rhand_joints_ori': rhand_joints.detach().cpu().numpy(),
# }
# sv_pert_dict_fn = "tot_pert_jts_sequence_dict.npy"
# np.save(sv_pert_dict_fn, sv_pert_dict)
# print(f"pert data saved to {sv_pert_dict_fn} !!!!")
if self.args.rnd_noise:
    pert_joints_offset_sequence = noise_joints_offset_sequence
    # pert_avg_joints_sequence = noise_avg_joints_sequence
if not self.args.use_arti_obj:  ## minus normed base pts here ##
    pert_rel_base_pts_to_joints_for_jts_pred = pert_joints_offset_sequence.unsqueeze(-2) - normed_base_pts.unsqueeze(1).unsqueeze(1)
else:
    pert_rel_base_pts_to_joints_for_jts_pred = pert_joints_offset_sequence.unsqueeze(-2) - normed_base_pts.unsqueeze(2)
# the strategy of adding noise to the representations #
# tot_pert_joint = pert_joints_offset_sequence * std_joints_sequence.unsqueeze(1) + pert_avg_joints_sequence.unsqueeze(1)
# np.save("tot_pert_joint.npy", tot_pert_joint.detach().cpu().numpy())
####### E #######
# noise_e_disp_rel_to_base_along_normals, noise_e_disp_rel_to_base_vt_normals #
# per_frame_avg_disp_along_normals, per_frame_std_disp_along_normals; wo_e_normalization #
# per_frame_avg_disp_vt_normals, per_frame_std_disp_vt_normals #
# noise_e_disp_rel_to_base_along_normals = torch.randn_like(e_disp_rel_to_base_along_normals)
# pert_e_disp_rel_to_base_along_normals = self.q_sample(
#     e_disp_rel_to_base_along_normals, my_t, noise_e_disp_rel_to_base_along_normals
# )
# noise_e_disp_rel_to_base_vt_normals = torch.randn_like(e_disp_rel_to_baes_vt_normals)
# pert_e_disp_rel_to_base_vt_normals = self.q_sample(
#     e_disp_rel_to_baes_vt_normals, my_t, noise_e_disp_rel_to_base_vt_normals
# )
####### E #######
# base_pts_trans, base_normals_trans #
# avg_transl, std_transl; avg_rot, std_rot; avg_theta, std_theta #
input_data = {
    'base_pts': base_pts,
    'base_normals': base_normals,
    # 'rel_base_pts_to_rhand_joints': rel_base_pts_to_pert_rhand_joints,
    # 'dist_base_pts_to_rhand_joints': dist_base_pts_to_pert_rhand_joints,
    # 'pert_rhand_joints': pert_normed_rhand_joints,
    # 'pert_rhand_joints': pert_scaled_rhand_joints,
    'rhand_joints': rhand_joints,
    # 'rel_base_pts_to_rhand_joints': rel_base_pts_to_rhand_joints.clone(),
    # 'avg_joints_sequence': avg_joints_sequence,
    # 'pert_avg_joints_sequence': pert_avg_joints_sequence,  ## pert avg joints sequence ##
    'pert_joints_offset_sequence': pert_joints_offset_sequence,
    'pert_rel_base_pts_to_rhand_joints': pert_rel_base_pts_to_rhand_joints,  ## pert relative base pts to rhand joints ##
    'normed_base_pts': normed_base_pts,
    'pert_rel_base_pts_to_joints_for_jts_pred': pert_rel_base_pts_to_joints_for_jts_pred,  # bsz x nf x nnj x nnb x 3 --> from base points to joints #
    'pert_rhand_params': pert_rhand_params,
    'rhand_betas': init_image['rhand_betas'],
    'base_pts_trans': base_pts_trans,
    'base_normals_trans': base_normals_trans,
    'avg_transl': avg_transl, 'std_transl': std_transl,
    'avg_rot': avg_rot, 'std_rot': std_rot,
    'avg_theta': avg_theta, 'std_theta': std_theta,
    # obj_rot, obj_transl #
    'obj_rot': obj_rot,
    'obj_transl': obj_transl,
}
# obj_verts, obj_faces #
if 'obj_verts' in init_image:
    input_data.update(
        {
            'obj_verts': init_image['obj_verts'],
            'obj_faces': init_image['obj_faces'],
        }
    )
####### E #######
# primal-space denoising (energies kept for reference):
# input_data.update(
#     {
#         'e_disp_rel_to_base_along_normals': e_disp_rel_to_base_along_normals,
#         'e_disp_rel_to_baes_vt_normals': e_disp_rel_to_baes_vt_normals,
#         'pert_e_disp_rel_to_base_along_normals': pert_e_disp_rel_to_base_along_normals,
#         'pert_e_disp_rel_to_base_vt_normals': pert_e_disp_rel_to_base_vt_normals,
#     }
# )
####### E #######
# input #
input_data.update(init_image_avg_std_stats)
input_data['rhand_joints'] = rhand_joints  # normed #
# when real_basejtsrel_norm_stra == "std", also expose the per-sequence rel stats #
if self.args.real_basejtsrel_norm_stra == "std":
    input_data.update(
        {
            'avg_rel_base_pts_to_rhand_joints': avg_rel_base_pts_to_rhand_joints,
            'std_rel_base_pts_to_rhand_joints': std_rel_base_pts_to_rhand_joints,
        }
    )
if self.args.train_enc:
    # model(input_data, self._scale_timesteps(t).clone())
    out_dict = model(input_data, self._scale_timesteps(my_t).clone())
    obj_base_pts_feats = out_dict['obj_base_pts_feats']  # obj base pts feats #
    # noise_obj_base_pts_feats = torch.zeros_like(obj_base_pts_feats)
    noise_obj_base_pts_feats = torch.randn_like(obj_base_pts_feats)
    pert_obj_base_pts_feats = self.q_sample(
        obj_base_pts_feats.permute(1, 0, 2), my_t, noise_obj_base_pts_feats.permute(1, 0, 2)
    ).permute(1, 0, 2)
    if self.args.rnd_noise:
        pert_obj_base_pts_feats = noise_obj_base_pts_feats
    input_data['pert_obj_base_pts_feats'] = pert_obj_base_pts_feats
model_kwargs = {
    k: val for k, val in init_image.items() if k not in input_data
}
if progress:
    # Lazy import so that we don't depend on tqdm.
    from tqdm.auto import tqdm
    indices = tqdm(indices)
for i_idx, i in enumerate(indices):
    t = th.tensor([i] * shape[0], device=device)
    if randomize_class and 'y' in model_kwargs:
        model_kwargs['y'] = th.randint(low=0, high=model.num_classes,
                                       size=model_kwargs['y'].shape,  # size of y
                                       device=model_kwargs['y'].device)  # device of y
    with th.no_grad():
        # inter_optim # progress #
        # p_sample_with_grad vs. p_sample; or, for each joint -> the features #
        sample_fn = self.p_sample_with_grad if cond_fn_with_grad else self.p_sample
        out = sample_fn(
            model,
            input_data,  ## sample from input data ##
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            cond_fn=cond_fn,
            model_kwargs=model_kwargs,
            ## const_noise=const_noise, ##
            # new representation strategies; resolve penetrations for the new representations #
        )
        # diff_basejtsrel: projection #
        if self.diff_basejtsrel:
            params_output = out["params_output"]  ## basejtsrel output ##
            # 'real_dec_basejtsrel': real_dec_basejtsrel,
            # 'dec_obj_base_pts_feats': dec_obj_base_pts_feats,
            # if self.args.pred_joints_offset:
            # basejtsrel_seq_latents_sample: bsz x nf x nnj x 3 #
            # sampled_rhand_joints = basejtsrel_seq_latents_sample * std_joints_sequence.unsqueeze(1) + avg_jts_outputs_sample.unsqueeze(1)
            # rhand_transl, rhand_rot, rhand_theta #
            # avg_transl, std_transl; avg_rot, std_rot; avg_theta, std_theta #
            # rhand_betas = init_image['rhand_betas'].cuda()  # bsz x ws x 10, or only bsz x 10 here #
            # rhand_transl = init_image['rhand_transl']  # bsz x ws x 3 --> translational vectors #
            # rhand_rot = init_image['rhand_rot']  # bsz x ws x 3 -> rot var #
            # rhand_theta = init_image['rhand_theta']  # bsz x ws x 24 -> theta var #
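# NOTE: reference only (names here are assumed, not from this file). Each loop
# iteration above draws x_{t-1} via p_sample, i.e. the standard ancestral DDPM step
#     x_{t-1} = mu_theta(x_t, t) + sigma_t * z,  z ~ N(0, I), with no noise at t == 0.
# A minimal sketch of that update, mirroring how p_sample applies it:
#     noise = th.randn_like(out_mean)
#     nonzero_mask = (t != 0).float().view(-1, *([1] * (out_mean.dim() - 1)))
#     sample = out_mean + nonzero_mask * th.exp(0.5 * out_log_variance) * noise
# where out_mean / out_log_variance would come from p_mean_variance.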
if i_idx == len(indices) - 1:
    # params_output = self.phy_project_pred_verts(rhand_params, input_data['rhand_betas'], input_data['base_pts_trans'],
    #                                             input_data['base_normals_trans'], input_data, nn_iters=0, wccd=True)
    # [rhand_transl, rhand_rot, rhand_theta], dim=-1  # bsz x ws x (3 + 3 + 24)
    rhand_betas = init_image['rhand_betas']  # bsz x 10; needed by pure_ccd and mano_layer below
    self.pure_ccd(
        init_image['rhand_rot'], init_image['rhand_theta'], rhand_betas, init_image['rhand_transl'],
        input_data,
        input_data['base_pts_trans'],
    )
sampled_transl, sampled_rot, sampled_theta = params_output[:, :, :3], params_output[:, :, 3:6], params_output[:, :, 6:]
sampled_transl = sampled_transl * std_transl + avg_transl
sampled_rot = sampled_rot * std_rot + avg_rot
sampled_theta = sampled_theta * std_theta + avg_theta
bsz, nframes = sampled_transl.size()[:2]
# mano_layer: recover verts / joints from the de-standardized MANO parameters #
sampled_rhand_verts, sampled_rhand_joints = self.mano_layer(
    torch.cat([sampled_rot.view(bsz * nframes, -1), sampled_theta.view(bsz * nframes, -1)], dim=-1),
    rhand_betas.unsqueeze(1).repeat(1, sampled_transl.size(1), 1).view(bsz * nframes, -1),
    sampled_transl.view(bsz * nframes, -1)
)
sampled_rhand_verts = sampled_rhand_verts.view(bsz, nframes, -1, 3).contiguous() * 0.001
sampled_rhand_joints = sampled_rhand_joints.view(bsz, nframes, -1, 3).contiguous() * 0.001
# params_output: bsz x ws x 3; bsz x ws x 3; bsz x ws x 24 #
rhand_betas_exp = rhand_betas.unsqueeze(1).repeat(1, nframes, 1)
# note: sampled_rhand_joints is repurposed to carry the full parameter vector from here on #
sampled_rhand_joints = torch.cat(
    [sampled_transl, sampled_rot, sampled_theta, rhand_betas_exp], dim=-1  # bsz x ws x (3 + 3 + 24 + 10), the rhand parameters #
)
# sampled_rhand_joints = basejtsrel_seq_latents_sample * std_joints_sequence.unsqueeze(1) + avg_joints_sequence.unsqueeze(1)
# print(f"basejtsrel_seq_latents_sample: {basejtsrel_seq_latents_sample.size()}, normed_base_pts: {normed_base_pts.size()}")
# ### pert rel base pts to rhand joints; normed base pts ###
# if not self.args.use_arti_obj:
#     pert_rel_base_pts_to_rhand_joints = basejtsrel_seq_latents_sample.unsqueeze(-2) - normed_base_pts.unsqueeze(1).unsqueeze(1)
# else:
#     pert_rel_base_pts_to_rhand_joints = basejtsrel_seq_latents_sample.unsqueeze(-2) - normed_base_pts.unsqueeze(2)
# print(f"Sampling with pert_rel_base_pts_to_rhand_joints: {pert_rel_base_pts_to_rhand_joints.size()}")
basejtsrel_seq_dec_in_dict = {
    # finetune_with_cond #
    # 'pert_avg_joints_sequence': out["avg_jts_outputs_sample"] if 'avg_jts_outputs_sample' in out else pert_avg_joints_sequence,  ## for the avg-jts sequence ##
    # 'pert_rel_base_pts_to_rhand_joints': pert_rel_base_pts_to_rhand_joints,  ## pert relative base pts to rhand joints ##
    'sampled_rhand_joints': sampled_rhand_joints,  ## sampled rhand joints ##
    'pert_rhand_params': out["params_output"],
}
input_data.update(basejtsrel_seq_dec_in_dict)
else:
    # basejtsrel_seq_input_dict = {}
    basejtsrel_seq_dec_in_dict = {}
if self.args.diff_realbasejtsrel_to_joints:
    # predicted x_{t-1} (normalized), rel to joints #
    dec_joints_offset_output = out['dec_joints_offset_output']
    if not self.args.use_arti_obj:  ## minus normed base pts: from normed pts and offset outputs ##
        pert_rel_base_pts_to_joints_for_jts_pred = dec_joints_offset_output.unsqueeze(-2) - normed_base_pts.unsqueeze(1).unsqueeze(1)
    else:
        pert_rel_base_pts_to_joints_for_jts_pred = dec_joints_offset_output.unsqueeze(-2) - normed_base_pts.unsqueeze(2)
    # predicted x_{t-1} before normalization #
    sampled_rhand_joints = dec_joints_offset_output * std_joints_sequence.unsqueeze(1) + avg_joints_sequence.unsqueeze(1)
    realbasejtsrel_to_joints_dec_in_dict = {
        'pert_rel_base_pts_to_joints_for_jts_pred': pert_rel_base_pts_to_joints_for_jts_pred,  # bsz x nf x nnj x nnb x 3 #
        'sampled_rhand_joints': sampled_rhand_joints,
        'pert_joints_offset_sequence': dec_joints_offset_output,
    }
    input_data.update(realbasejtsrel_to_joints_dec_in_dict)
if self.diff_realbasejtsrel:
    real_dec_basejtsrel = out["real_dec_basejtsrel"]  # bsz x nf x nnj x nnb x 3 #
    # avg_rel_base_pts_to_rhand_joints, std_rel_base_pts_to_rhand_joints.unsqueeze(1).unsqueeze(1) #
    #### add_noise_onjts; add_noise_onjts_single ####
    if self.args.real_basejtsrel_norm_stra == "std" and (not self.args.add_noise_onjts) and (not self.args.add_noise_onjts_single):
        real_dec_basejtsrel_pred_sample = real_dec_basejtsrel * std_rel_base_pts_to_rhand_joints.unsqueeze(1).unsqueeze(1) + avg_rel_base_pts_to_rhand_joints.unsqueeze(1).unsqueeze(1)
        real_dec_basejtsrel_pred_sample = real_dec_basejtsrel_pred_sample + normed_base_pts.unsqueeze(1).unsqueeze(1)
    else:
        # real_dec_basejtsrel_pred_sample = real_dec_basejtsrel.clone()
        if not self.args.use_arti_obj:
            real_dec_basejtsrel_pred_sample = real_dec_basejtsrel + normed_base_pts.unsqueeze(1).unsqueeze(1)
        else:
            real_dec_basejtsrel_pred_sample = real_dec_basejtsrel + normed_base_pts.unsqueeze(2)
    # real_dec_basejtsrel_pred_sample: bsz x nf x nnj x nnb x 3 #
    # std_joints_sequence.unsqueeze(1), avg_joints_sequence.unsqueeze(1) # real pred samples #
    # if self.args.use_t == 1000:
    #     sampled_rhand_joints = real_dec_basejtsrel_pred_sample[..., self.args.sel_basepts_idx, :]
    # else:
    #     sampled_rhand_joints = real_dec_basejtsrel_pred_sample.mean(dim=-2)  # bsz x nf x nnj x 3 #
    # each base point carries its own joint estimate (normed base pt + predicted offset);
    # either select a single base point, or average the estimates over the base-point axis #
    if self.args.sel_basepts_idx >= 0:
        sampled_rhand_joints = real_dec_basejtsrel_pred_sample[..., self.args.sel_basepts_idx, :]
    else:
        sampled_rhand_joints = real_dec_basejtsrel_pred_sample.mean(dim=-2)
    # sampled_rhand_joints = real_dec_basejtsrel_pred_sample[..., 0, :]
    # undo the per-sequence joint normalization #
    sampled_rhand_joints = sampled_rhand_joints * std_joints_sequence.unsqueeze(1) + avg_joints_sequence.unsqueeze(1)
    real_basejtsrel_dec_in_dict = {
        'pert_rel_base_pts_to_rhand_joints': real_dec_basejtsrel,  ## real_dec_basejtsrel ##
        # 'sampled_rhand_joints': sampled_rhand_joints,
    }
    if not self.diff_basejtsrel:
        real_basejtsrel_dec_in_dict['sampled_rhand_joints'] = sampled_rhand_joints
    if self.args.train_enc:
        real_basejtsrel_dec_in_dict['pert_obj_base_pts_feats'] = out['dec_obj_base_pts_feats']
    input_data.update(real_basejtsrel_dec_in_dict)
else:
    real_basejtsrel_dec_in_dict = {}
if self.diff_basejtse:
    ## seq latents ##
    # basejtse_seq_latents_sample = out["basejtse_seq_latents_sample"]
    pert_dec_e_along_normals = out["dec_e_along_normals"]
    pert_dec_e_vt_normals = out["dec_e_vt_normals"]
    dec_e_along_normals = pert_dec_e_along_normals * init_image['per_frame_std_disp_along_normals'] + init_image['per_frame_avg_disp_along_normals']
    dec_e_vt_normals = pert_dec_e_vt_normals * init_image['per_frame_std_disp_vt_normals'] + init_image['per_frame_avg_disp_vt_normals']
    # energies are non-negative magnitudes, so clamp at zero #
    dec_e_along_normals = torch.clamp(dec_e_along_normals, min=0.)
    dec_e_vt_normals = torch.clamp(dec_e_vt_normals, min=0.)
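# NOTE: illustrative sketch only (toy shapes assumed). The two lines above invert
# the dataset's per-frame standardization e_norm = (e - avg) / std and then clamp
# at zero, since both energies are magnitudes:
#     import torch
#     e_norm = torch.randn(2, 7, 21, 64)               # bsz x (nf-1) x nnj x nnb
#     avg = torch.zeros(7, 21, 64); std = torch.ones(7, 21, 64)
#     e = torch.clamp(e_norm * std + avg, min=0.)      # de-standardize, then clamp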
# scale base ## model constraints and model impacts from object a to object c ## basejtse_seq_input_dict = { 'pert_e_disp_rel_to_base_along_normals': pert_dec_e_along_normals, 'pert_e_disp_rel_to_base_vt_normals': pert_dec_e_vt_normals, 'e_disp_rel_to_base_along_normals': dec_e_along_normals, 'e_disp_rel_to_baes_vt_normals': dec_e_vt_normals, 'sampled_rhand_joints': rhand_joints, } input_data.update(basejtse_seq_input_dict) else: basejtse_seq_input_dict = {} # basejtse_seq_dec_in_dict = {} yield input_data def ddim_sample( self, model, x, t, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, eta=0.0, ): """ Sample x_{t-1} from the model using DDIM. Same usage as p_sample(). """ out_orig = self.p_mean_variance( model, x, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) if cond_fn is not None: out = self.condition_score(cond_fn, out_orig, x, t, model_kwargs=model_kwargs) else: out = out_orig # Usually our model outputs epsilon, but we re-derive it # in case we used x_start or x_prev prediction. eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) sigma = ( eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * th.sqrt(1 - alpha_bar / alpha_bar_prev) ) # Equation 12. ##; hand position and relative noise = th.randn_like(x) mean_pred = ( out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps ) nonzero_mask = ( (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) ) # no noise when t == 0 sample = mean_pred + nonzero_mask * sigma * noise return {"sample": sample, "pred_xstart": out_orig["pred_xstart"]} def ddim_sample_with_grad( self, model, x, t, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, eta=0.0, ): """ Sample x_{t-1} from the model using DDIM. Same usage as p_sample(). """ with th.enable_grad(): x = x.detach().requires_grad_() out_orig = self.p_mean_variance( model, x, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) if cond_fn is not None: out = self.condition_score_with_grad(cond_fn, out_orig, x, t, model_kwargs=model_kwargs) else: out = out_orig out["pred_xstart"] = out["pred_xstart"].detach() # Usually our model outputs epsilon, but we re-derive it # in case we used x_start or x_prev prediction. eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) sigma = ( eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * th.sqrt(1 - alpha_bar / alpha_bar_prev) ) # Equation 12. noise = th.randn_like(x) mean_pred = ( out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps ) nonzero_mask = ( (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) ) # no noise when t == 0 sample = mean_pred + nonzero_mask * sigma * noise return {"sample": sample, "pred_xstart": out_orig["pred_xstart"].detach()} ## def ddim_reverse_sample( self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None, eta=0.0, ): """ Sample x_{t+1} from the model using DDIM reverse ODE. 
""" assert eta == 0.0, "Reverse ODE only for deterministic path" out = self.p_mean_variance( model, x, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) # Usually our model outputs epsilon, but we re-derive it # in case we used x_start or x_prev prediction. eps = ( _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - out["pred_xstart"] ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) # Equation 12. reversed mean_pred = ( out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps ) return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} def ddim_sample_loop( self, model, shape, noise=None, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, progress=False, eta=0.0, skip_timesteps=0, init_image=None, randomize_class=False, cond_fn_with_grad=False, dump_steps=None, const_noise=False, ): """ Generate samples from the model using DDIM. Same usage as p_sample_loop(). """ if dump_steps is not None: raise NotImplementedError() if const_noise == True: raise NotImplementedError() final = None for sample in self.ddim_sample_loop_progressive( model, shape, noise=noise, clip_denoised=clip_denoised, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, device=device, progress=progress, eta=eta, skip_timesteps=skip_timesteps, init_image=init_image, randomize_class=randomize_class, cond_fn_with_grad=cond_fn_with_grad, ): final = sample return final["sample"] def ddim_sample_loop_progressive( self, model, shape, noise=None, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, progress=False, eta=0.0, skip_timesteps=0, init_image=None, randomize_class=False, cond_fn_with_grad=False, ): """ Use DDIM to sample from the model and yield intermediate samples from each timestep of DDIM. Same usage as p_sample_loop_progressive(). """ if device is None: device = next(model.parameters()).device assert isinstance(shape, (tuple, list)) if noise is not None: img = noise else: img = th.randn(*shape, device=device) if skip_timesteps and init_image is None: init_image = th.zeros_like(img) indices = list(range(self.num_timesteps - skip_timesteps))[::-1] if init_image is not None: my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0] img = self.q_sample(init_image, my_t, img) if progress: # Lazy import so that we don't depend on tqdm. from tqdm.auto import tqdm indices = tqdm(indices) for i in indices: t = th.tensor([i] * shape[0], device=device) if randomize_class and 'y' in model_kwargs: model_kwargs['y'] = th.randint(low=0, high=model.num_classes, size=model_kwargs['y'].shape, device=model_kwargs['y'].device) with th.no_grad(): sample_fn = self.ddim_sample_with_grad if cond_fn_with_grad else self.ddim_sample out = sample_fn( model, img, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, eta=eta, ) yield out img = out["sample"] def plms_sample( self, model, x, t, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, cond_fn_with_grad=False, order=2, old_out=None, ): """ Sample x_{t-1} from the model using Pseudo Linear Multistep. Same usage as p_sample(). 
""" if not int(order) or not 1 <= order <= 4: raise ValueError('order is invalid (should be int from 1-4).') def get_model_output(x, t): with th.set_grad_enabled(cond_fn_with_grad and cond_fn is not None): x = x.detach().requires_grad_() if cond_fn_with_grad else x out_orig = self.p_mean_variance( model, x, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) if cond_fn is not None: if cond_fn_with_grad: out = self.condition_score_with_grad(cond_fn, out_orig, x, t, model_kwargs=model_kwargs) x = x.detach() else: out = self.condition_score(cond_fn, out_orig, x, t, model_kwargs=model_kwargs) else: out = out_orig # Usually our model outputs epsilon, but we re-derive it # in case we used x_start or x_prev prediction. eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) return eps, out, out_orig alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) eps, out, out_orig = get_model_output(x, t) if order > 1 and old_out is None: # Pseudo Improved Euler old_eps = [eps] mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev) * eps eps_2, _, _ = get_model_output(mean_pred, t - 1) eps_prime = (eps + eps_2) / 2 pred_prime = self._predict_xstart_from_eps(x, t, eps_prime) mean_pred = pred_prime * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev) * eps_prime else: # Pseudo Linear Multistep (Adams-Bashforth) old_eps = old_out["old_eps"] old_eps.append(eps) cur_order = min(order, len(old_eps)) if cur_order == 1: eps_prime = old_eps[-1] elif cur_order == 2: eps_prime = (3 * old_eps[-1] - old_eps[-2]) / 2 elif cur_order == 3: eps_prime = (23 * old_eps[-1] - 16 * old_eps[-2] + 5 * old_eps[-3]) / 12 elif cur_order == 4: eps_prime = (55 * old_eps[-1] - 59 * old_eps[-2] + 37 * old_eps[-3] - 9 * old_eps[-4]) / 24 else: raise RuntimeError('cur_order is invalid.') pred_prime = self._predict_xstart_from_eps(x, t, eps_prime) mean_pred = pred_prime * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev) * eps_prime if len(old_eps) >= order: old_eps.pop(0) nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) sample = mean_pred * nonzero_mask + out["pred_xstart"] * (1 - nonzero_mask) return {"sample": sample, "pred_xstart": out_orig["pred_xstart"], "old_eps": old_eps} def plms_sample_loop( self, model, shape, noise=None, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, progress=False, skip_timesteps=0, init_image=None, randomize_class=False, cond_fn_with_grad=False, order=2, ): """ Generate samples from the model using Pseudo Linear Multistep. Same usage as p_sample_loop(). """ final = None for sample in self.plms_sample_loop_progressive( model, shape, noise=noise, clip_denoised=clip_denoised, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, device=device, progress=progress, skip_timesteps=skip_timesteps, init_image=init_image, randomize_class=randomize_class, cond_fn_with_grad=cond_fn_with_grad, order=order, ): final = sample return final["sample"] def plms_sample_loop_progressive( self, model, shape, noise=None, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, progress=False, skip_timesteps=0, init_image=None, randomize_class=False, cond_fn_with_grad=False, order=2, ): """ Use PLMS to sample from the model and yield intermediate samples from each timestep of PLMS. Same usage as p_sample_loop_progressive(). 
""" if device is None: device = next(model.parameters()).device assert isinstance(shape, (tuple, list)) if noise is not None: img = noise else: img = th.randn(*shape, device=device) if skip_timesteps and init_image is None: init_image = th.zeros_like(img) indices = list(range(self.num_timesteps - skip_timesteps))[::-1] if init_image is not None: my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0] img = self.q_sample(init_image, my_t, img) if progress: # Lazy import so that we don't depend on tqdm. from tqdm.auto import tqdm indices = tqdm(indices) old_out = None for i in indices: t = th.tensor([i] * shape[0], device=device) if randomize_class and 'y' in model_kwargs: model_kwargs['y'] = th.randint(low=0, high=model.num_classes, size=model_kwargs['y'].shape, device=model_kwargs['y'].device) with th.no_grad(): out = self.plms_sample( model, img, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, cond_fn_with_grad=cond_fn_with_grad, order=order, old_out=old_out, ) yield out old_out = out img = out["sample"] def _vb_terms_bpd( self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None ): """ Get a term for the variational lower-bound. The resulting units are bits (rather than nats, as one might expect). This allows for comparison to other papers. :return: a dict with the following keys: - 'output': a shape [N] tensor of NLLs or KLs. - 'pred_xstart': the x_0 predictions. """ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( x_start=x_start, x_t=x_t, t=t ) out = self.p_mean_variance( model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs ) kl = normal_kl( true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] ) kl = mean_flat(kl) / np.log(2.0) decoder_nll = -discretized_gaussian_log_likelihood( x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] ) assert decoder_nll.shape == x_start.shape decoder_nll = mean_flat(decoder_nll) / np.log(2.0) # At the first timestep return the decoder NLL, # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) output = th.where((t == 0), decoder_nll, kl) return {"output": output, "pred_xstart": out["pred_xstart"]} ## ## training losses ## ## training losses ## def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, dataset=None): """ # training losses # training losses for rel/dist representations ## Compute training losses for a single timestep. :param model: the model to evaluate loss on. :param x_start: the [N x C x ...] tensor of inputs. :param t: a batch of timestep indices. :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :param noise: if specified, the specific Gaussian noise to try to remove. :return: a dict with the key "loss" containing a tensor of shape [N]. Some mean or variance settings may also have other keys. 
""" # if self.args.train_diff: # set enc to evals # # # print(f"Setitng encoders to eval mode") # model.model.set_enc_to_eval() enc = model.model ## model.model mask = model_kwargs['y']['mask'] ## rot2xyz get_xyz = lambda sample: enc.rot2xyz(sample, mask=None, pose_rep=enc.pose_rep, translation=enc.translation, glob=enc.glob, ## rot2xyz ## ## rot2xyz ## # jointstype='vertices', # 3.4 iter/sec # USED ALSO IN MotionCLIP jointstype='smpl', # 3.4 iter/sec vertstrans=False) # bsz x ws x nnj x 3 # # # base_pts: bsz x nnb x 3 # # base_normals: bsz x nnb x 3 # # base normals # base normals # # bsz x ws x nnjts x 3 # rhand_joints = x_start['rhand_joints'] # bsz x nnbase x 3 # base_pts = x_start['base_pts'] # bsz x ws x nnbase x 3 # base_normals = x_start['base_normals'] # rhand_transl, rhand_rot, rhand_theta # rhand_transl = x_start['rhand_transl'] # bsz x ws x 3 --> translational vectors # rhand_rot = x_start['rhand_rot'] # bsz x ws x 3 -> rot var # rhand_theta = x_start['rhand_theta'] # bsz x ws x 24 -> theta var # ### avg, std ### avg_transl, std_transl = torch.mean(rhand_transl, dim=1, keepdim=True), torch.std(rhand_transl, dim=1, keepdim=True) avg_rot, std_rot = torch.mean(rhand_rot, dim=1, keepdim=True), torch.std(rhand_rot, dim=1, keepdim=True) avg_theta, std_theta = torch.mean(rhand_theta, dim=1, keepdim=True), torch.std(rhand_theta, dim=1, keepdim=True) rhand_transl = (rhand_transl - avg_transl) / std_transl # xxx rhand_rot = (rhand_rot - avg_rot) / std_rot # bsz x ws x 3 # rhand_theta= (rhand_theta - avg_theta) / std_theta ### bsz x ws x 24 rhand_params = torch.cat( [rhand_transl, rhand_rot, rhand_theta], dim=-1 # bsz x ws x (3 + 3 + 24) ) if self.args.use_anchors: # rhand_joints: bsz x nf x nn_anchors x 3 # ## rhand verts ## rhand_joints = x_start['rhand_anchors'] ## bsz x nf x nn_anchors x 3 -> for the anchors of the rhand # # base_pts, base_normals, rhand_joints # ### rhand verts ## avg_joints_sequence = torch.mean(rhand_joints, dim=1) # bsz x nnj x 3 ---> for rhand joints ## std_joints_sequence = torch.std(rhand_joints, dim=1) if self.args.joint_std_v2: std_joints_sequence = torch.std(rhand_joints.view(rhand_joints.size(0), -1, 3), dim=1).unsqueeze(1) avg_joints_sequence = torch.mean(rhand_joints.view(rhand_joints.size(0), -1, 3), dim=1).unsqueeze(1) elif self.args.joint_std_v3: # std_joints_sequence = torch.std(rhand_joints.view(rhand_joints.size(0), -1, 3), dim=1).unsqueeze(1) avg_joints_sequence = torch.mean(rhand_joints.view(rhand_joints.size(0), -1, 3), dim=1).unsqueeze(1) # bsz x 3 --> bsz x 1 x 3; std_joints_sequence = torch.std((rhand_joints - avg_joints_sequence.unsqueeze(1)).view(rhand_joints.size(0), -1), dim=1).unsqueeze(1).unsqueeze(1) # if self.args.use_anchor: # avg_joints_sequence = torch.mean(rhand_joints, dim=1) # normed_base_pts, joints_offset_sequence # joints_offset_sequence = rhand_joints - avg_joints_sequence.unsqueeze(1) # bsz x nf x nnj x 3 # # bsz x nf x nnj x 3 # normed_base_pts = base_pts - avg_joints_sequence # bsz x nnb x 3 if not self.args.use_arti_obj: normed_base_pts = base_pts - avg_joints_sequence # bsz x nnb x 3 else: # base_pts: bsz x nf x nnb x 3 # normed_base_pts = base_pts - avg_joints_sequence.unsqueeze(1) # bsz x nf x nnb x 3 # if self.args.jts_sclae_stra == "std": ## jts scale stra ## joints_offset_sequence = joints_offset_sequence / std_joints_sequence.unsqueeze(1) # normed_base_pts = normed_base_pts / std_joints_sequence if not self.args.use_arti_obj: normed_base_pts = normed_base_pts / std_joints_sequence else: normed_base_pts = 
normed_base_pts / std_joints_sequence.unsqueeze(1) else: std_joints_sequence = torch.ones_like(std_joints_sequence) # # bsz x ws x nnjts x nnbase x 3 # # rel_base_pts_to_rhand_joints = x_start['rel_base_pts_to_rhand_joints'] # # bsz x ws x nnjts x nnbase # # dist_base_pts_to_rhand_joints = x_start['dist_base_pts_to_rhand_joints'] # if 'sampled_base_pts_nearest_obj_pc' in x_start: # ambient_xstart_dict = { # 'sampled_base_pts_nearest_obj_pc': x_start['sampled_base_pts_nearest_obj_pc'], # 'sampled_base_pts_nearest_obj_vns': x_start['sampled_base_pts_nearest_obj_vns'], # } # per_frame_avg_disp_along_normals, per_frame_std_disp_along_normals # wo e normalization; # per_frame_avg_disp_vt_normals, per_frame_std_disp_vt_normals # wo_e_normalization # # if self.args.wo_e_normalization: # per frame avg disp along normals # # x_start['per_frame_avg_disp_along_normals'] = torch.zeros_like(x_start['per_frame_avg_disp_along_normals']) # x_start['per_frame_avg_disp_vt_normals'] = torch.zeros_like(x_start['per_frame_avg_disp_vt_normals']) # x_start['per_frame_std_disp_along_normals'] = torch.ones_like(x_start['per_frame_std_disp_along_normals']) # x_start['per_frame_std_disp_vt_normals'] = torch.ones_like(x_start['per_frame_std_disp_vt_normals']) # psatial -> e normalization and centralize? if self.args.wo_rel_normalization: # per_frame_avg_joints_rel, per_frame_std_joints_rel # x_start['per_frame_avg_joints_rel'] = torch.zeros_like(x_start['per_frame_avg_joints_rel']) x_start['per_frame_std_joints_rel'] = torch.ones_like(x_start['per_frame_std_joints_rel']) # normed_base_pts, joints_offset_sequence # ## rel_base_pts_to_rhand_joints: bsz x ws x nnj x nnb x 3 ## ## base pts to rhand joints ## # base_pts: bsz x nnb x 3 # # base_normals: bsz x nnb x 3 ## ## relative joint positions ### ## bsz x # rel_base_pts_to_rhand_joints = rhand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1) # bsz x nnj x nnb x 3 # rhand_joints ## # x_start['per_frame_avg_joints_rel'] = torch.mean(rel_base_pts_to_rhand_joints, dim=1, keepdim=True) # # x_start['per_frame_avg_joints_rel'] = torch # # bsz x ws x nnj x nnb x 3 # # per_frame_avg_joints_rel # # rel_base_pts_to_rhand_joints = (rel_base_pts_to_rhand_joints - x_start['per_frame_avg_joints_rel']) / x_start['per_frame_std_joints_rel'] # rel_base_pts_to_rhand_joints ## rel_base_pts_to_rhand_joints -> joints offset ## Normalization stra 1 --> no normalization for joints sequences ## # normed base pts # rel_base_pts_to_rhand_joints = joints_offset_sequence.unsqueeze(-2) - normed_base_pts.unsqueeze(1).unsqueeze(1) # bsz x nnf x nnj x nnb x 3 --> relative positions from baes pts to rhand joints # if not self.args.use_arti_obj: ### rel base pts to rhand joints #joints offset sequence # joints offset sequence # joints # joints_offset_sequence - normed_base_pts rel_base_pts_to_rhand_joints = joints_offset_sequence.unsqueeze(-2) - normed_base_pts.unsqueeze(1).unsqueeze(1) # bsz x nnf x nnj x nnb x 3 --> relative positions from baes pts to rhand joints # else: # rel_base_pts_to_rhand_joints = joints_offset_sequence.unsqueeze(-2) - normed_base_pts.unsqueeze(2) # bsz x nf x nn_joints x nn_base_pts x 3 # # other normalization strategies> if self.args.real_basejtsrel_norm_stra == "mean": rel_base_pts_to_rhand_joints = rhand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1) bsz, nf, nnj, nnb = rel_base_pts_to_rhand_joints.size()[:4] # exp_rel_base_pts_to_rhand_joints = rel_base_pts_to_rhand_joints.view(bsz, nf * nnj * nnb, -1).contiguous() # avg_rel_base_pts_to_rhand_joints = 
torch.mean(exp_rel_base_pts_to_rhand_joints, dim=1, keepdim=True) # bsz x 1 x 3 # # std_rel_base_pts_to_rhand_joints = torch.std(exp_rel_base_pts_to_rhand_joints, dim=1, keepdim=True) # bsz x 1 x 3 # #### rel_base_pts_to_rhand_joints #### # rel_base_pts_to_rhand_joints = (rel_base_pts_to_rhand_joints - avg_rel_base_pts_to_rhand_joints.unsqueeze(1).unsqueeze(1)) # / std_rel_base_pts_to_rhand_joints.unsqueeze(1).unsqueeze(1) rel_base_pts_to_rhand_joints = (rel_base_pts_to_rhand_joints - rel_base_pts_to_rhand_joints.mean(dim=0, keepdim=True)) elif self.args.real_basejtsrel_norm_stra == "std": # rel_base_pts_to_rhand_joints = rhand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1) bsz, nf, nnj, nnb = rel_base_pts_to_rhand_joints.size()[:4] exp_rel_base_pts_to_rhand_joints = rel_base_pts_to_rhand_joints.view(bsz, nf * nnj * nnb, -1).contiguous() avg_rel_base_pts_to_rhand_joints = torch.mean(exp_rel_base_pts_to_rhand_joints, dim=1, keepdim=True) # bsz x 1 x 3 rel_base_pts_to_rhand_joints = rel_base_pts_to_rhand_joints - avg_rel_base_pts_to_rhand_joints.unsqueeze(1).unsqueeze(1) std_rel_base_pts_to_rhand_joints = torch.std(rel_base_pts_to_rhand_joints.view(bsz, -1), dim=-1, keepdim=True).unsqueeze(1) # bsz x 1 x 1 rel_base_pts_to_rhand_joints = rel_base_pts_to_rhand_joints / std_rel_base_pts_to_rhand_joints.unsqueeze(1).unsqueeze(1) # rel_base_pts_to_rhand_joints --> # # print("here using std!!") # std_rel_base_pts_to_rhand_joints = torch.std(exp_rel_base_pts_to_rhand_joints, dim=1, keepdim=True) # bsz x 1 x 3 #### rel_base_pts_to_rhand_joints #### # rel_base_pts_to_rhand_joints = (rel_base_pts_to_rhand_joints - avg_rel_base_pts_to_rhand_joints.unsqueeze(1).unsqueeze(1)) / std_rel_base_pts_to_rhand_joints.unsqueeze(1).unsqueeze(1) ## rep; motion-to-rep; # construct statistics, normalize values # # joints_scaling_factor = 5. # ''' GET rel and dists ''' ## rep and rhand_joints ##### # rep; motion-to-rep # if self.denoising_stra == "rep": # rel_base_pts_to_rhand_joints: bsz x nf x nnj x nnb x 3 # denormed_rel_base_pts_to_rhand_joints = rhand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1) # bsz x nf x nnj x nnb x 3 ##### E ###### # ### Calculate moving related energies ### # # denormed_rhand_joints: bsz x nf x nnj x 3 -> denormed rhand joints here # # # pert rhand joints # ## denormed base pts to rhand joints -> denormed rel positions ## # # denormed_rel_base_pts_to_rhand_joints = rel_base_pts_to_rhand_joints * x_start['per_frame_std_joints_rel'] + x_start['per_frame_avg_joints_rel'] # # denormed_rhand_joints = rhand_joints * extents_rhand_joints.unsqueeze(1) + avg_exp_rhand_joints.unsqueeze(1) # denormed_dist_rel_base_pts_to_pert_rhand_joints = torch.sum( # denormed_rel_base_pts_to_rhand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1 # denormed relative distances # ) ## l2 real base pts # k_f = 1. 
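# NOTE: reference only (mirroring, not replacing, self.q_sample). Every pert_* tensor
# built below comes from the standard DDPM forward corruption
#     q(x_t | x_0) = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I).
# Minimal standalone sketch, assuming alphas_cumprod is a 1-D tensor indexed by t:
#     import torch
#     def q_sample_ref(x0, alphas_cumprod, t, noise):
#         ab = alphas_cumprod[t].view(-1, *([1] * (x0.dim() - 1)))
#         return ab.sqrt() * x0 + (1. - ab).sqrt() * noise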
## l2 rel base pts to pert rhand joints ## # # l2_rel_base_pts_to_pert_rhand_joints: bsz x nf x nnj x nnb # # l2_rel_base_pts_to_pert_rhand_joints = torch.norm(denormed_rel_base_pts_to_rhand_joints, dim=-1) # ### att_forces ## # att_forces = torch.exp(-k_f * l2_rel_base_pts_to_pert_rhand_joints) # bsz x nf x nnj x nnb # # # bsz x (ws - 1) x nnj x nnb # # att_forces = att_forces[:, :-1, :, :] # attraction forces -1 # # # rhand_joints: ws x nnj x 3 # -> (ws - 1) x nnj x 3 ## rhand_joints ## # # bsz x (ws - 1) x nnj x 3 --> displacements s# # denormed_rhand_joints_disp = rhand_joints[:, 1:, :, :] - rhand_joints[:, :-1, :, :] # # distance -- base_normalss,; (ws - 1) x nnj x nnb x 3 --> bsz x (ws - 1) x nnj x nnb # # # signed_dist_base_pts_to_pert_rhand_joints_along_normal # bsz x (ws - 1) x nnj x nnb # # signed_dist_base_pts_to_rhand_joints_along_normal = torch.sum( # base_normals.unsqueeze(1).unsqueeze(1) * denormed_rhand_joints_disp.unsqueeze(-2), dim=-1 # ) ## signed dist base pts to rhand joints along normals # # # rel_base_pts_to_pert_rhand_joints_vt_normal: bsz x (ws -1) x nnj x nnb x 3 -> the relative positions vertical to base normals # # rel_base_pts_to_rhand_joints_vt_normal = denormed_rhand_joints_disp.unsqueeze(-2) - signed_dist_base_pts_to_rhand_joints_along_normal.unsqueeze(-1) * base_normals.unsqueeze(1).unsqueeze(1) # dist_base_pts_to_rhand_joints_vt_normal = torch.sqrt(torch.sum( # rel_base_pts_to_rhand_joints_vt_normal ** 2, dim=-1 # )) # k_a = 1. # k_b = 1. ### # e_disp_rel_to_base_along_normals = k_a * att_forces * torch.abs(signed_dist_base_pts_to_rhand_joints_along_normal) # # (ws - 1) x nnj x nnb # -> dist vt normals # ## # e_disp_rel_to_baes_vt_normals = k_b * att_forces * dist_base_pts_to_rhand_joints_vt_normal # # nf x nnj x nnb ---> dist_vt_normals -> nf x nnj x nnb # # torch.sqrt() ## # ##### E ###### # dist_base_pts_to_pert_rhand_joints, rel_base_pts_to_pert_rhand_joints # # e_disp_rel_to_base_along_normals, e_disp_rel_to_baes_vt_normals # # print("per frmae avg jalong normals,", x_start['per_frame_avg_disp_along_normals'].size(), "per frame std along normals", x_start['per_frame_std_disp_along_normals'].size(), "e_disp_rel_to_base_along_normals", e_disp_rel_to_base_along_normals.size(), "e_disp_rel_to_baes_vt_normals", e_disp_rel_to_baes_vt_normals.size()) # per_frame_avg_disp_along_normals, per_frame_std_disp_along_normals # per_frame_avg_disp_vt_normals, per_frame_std_disp_vt_normals # e_disp_rel_to_base_along_normals = (e_disp_rel_to_base_along_normals - x_start['per_frame_avg_disp_along_normals']) / x_start['per_frame_std_disp_along_normals'] # e_disp_rel_to_baes_vt_normals = (e_disp_rel_to_baes_vt_normals - x_start['per_frame_avg_disp_vt_normals']) / x_start['per_frame_std_disp_vt_normals'] elif self.denoising_stra == "motion_to_rep": # or sdfs. 
or # print(f"Using denoising stra: {self.denoising_stra}") joints_noise = torch.randn_like(rhand_joints) pert_rhand_joints = self.q_sample(rhand_joints, t, noise=joints_noise) # q_sample for the noisy joints # pert_rhand_joints: bsz x nf x nnj x 3 ## --> pert joints # base_pts: bsz x nnb x 3 # avg jts and std jts ## pert_rhand_joints_denorm = pert_rhand_joints * self.std_jts.unsqueeze(0).to(rhand_joints.device) + self.avg_jts.unsqueeze(0).to(rhand_joints.device) # jbs zx nf x nnj x nnb x 3 rel_base_pts_to_pert_rhand_joints = pert_rhand_joints_denorm.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1) dist_base_pts_to_pert_rhand_joints = torch.sum( rel_base_pts_to_pert_rhand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1 ) else: raise ValueError(f"Unrecognized denoising_stra: {self.denoising_stra}") ''' GET rel and dists ''' noise_rhand_params = th.randn_like(rhand_params) pert_rhand_params = self.q_sample( # bsz x ws x (3 + 3 + 24) rhand_params, t, noise_rhand_params ) # # noise_rel_base_pts_to_rhand_joints, rel_base_pts_to_rhand_joints # ## pred jts # pred jts # # ### Add noise to rel_baes_pts_to_rhand_joints ### # noise_rel_base_pts_to_rhand_joints = th.randn_like(rel_base_pts_to_rhand_joints) ## bsz x ws x nnjts x 3 -> for each joint point # # pert_rel_base_pts_to_rhand_joints = self.q_sample( ## bsz x ws x nnjts x 3 --> pert in the spatial space ## # rel_base_pts_to_rhand_joints, t, noise_rel_base_pts_to_rhand_joints # ) # # add_noise_onjts, add_noise_onjts_single # ### ==== add noise on joints ==== ### # ### ==== add noise on joints and then use them to calculate the perturbed rel-base-pts-to-rhand-joints ==== ### # if self.args.add_noise_onjts: # add_noise_onjts_single # --> # joints_offset_output_exp = joints_offset_sequence.unsqueeze(-2).repeat(1, 1, 1, normed_base_pts.size(1), 1) # noise_rel_base_pts_to_rhand_joints = torch.randn_like(joints_offset_output_exp) # pert_joints_offset_output_exp = self.q_sample( ## bsz x ws x nnjts x 3 ## # joints_offset_output_exp, t, noise_rel_base_pts_to_rhand_joints # ) # # pert_rel_base_pts_to_rhand_joints: bsz x seq_len x nnj x nnb x 3 --> the rhand-joints; # pert_rel_base_pts_to_rhand_joints = pert_joints_offset_output_exp - normed_base_pts.unsqueeze(1).unsqueeze(1) # # elif self.args.add_noise_onjts_single: # joints offset sequence # joints offset single # # noise_joints_offset_output = torch.randn_like(joints_offset_sequence) # pert_joints_offset_output = self.q_sample( ## bsz x ws x nnjts x 3 --> pert in the spatial space ### # joints_offset_sequence, t, noise_joints_offset_output # ) # pert_rel_base_pts_to_rhand_joints = pert_joints_offset_output.unsqueeze(-2) - normed_base_pts.unsqueeze(1).unsqueeze(1) # noise_rel_base_pts_to_rhand_joints = noise_joints_offset_output.unsqueeze(-2).repeat(1, 1, 1, normed_base_pts.size(1), 1) # # joints_offset_output_exp = joints_offset_sequence.unsqueeze(-2).repeat(1, 1, 1, normed_base_pts.size(1), 1) # # noise_rel_base_pts_to_rhand_joints = torch.randn_like(joints_offset_output_exp) # # noise_rel_base_pts_to_rhand_joints = noise_rel_base_pts_to_rhand_joints[:, :, :, 0:1, :].repeat(1, 1, 1, normed_base_pts.size(1), 1) # # pert_joints_offset_output_exp = self.q_sample( ## bsz x ws x nnjts x 3 --> pert in the spatial space ### # # joints_offset_output_exp, t, noise_rel_base_pts_to_rhand_joints # # ) # # pert_rel_base_pts_to_rhand_joints = pert_joints_offset_output_exp - normed_base_pts.unsqueeze(1).unsqueeze(1) # if self.args.train_enc: # pert_rel_base_pts_to_rhand_joints = 
rel_base_pts_to_rhand_joints # # bsz x ws x nnj x nnb x 3 # maxx_pert_basejtsrel, _ = torch.max(pert_rel_base_pts_to_rhand_joints.view(-1, 3), dim=0) # minn_pert_basejtsrel, _ = torch.min(pert_rel_base_pts_to_rhand_joints.view(-1, 3), dim=0) # maxx_basejtsrel, _ = torch.max(rel_base_pts_to_rhand_joints.view(-1, 3), dim=0) # minn_basejtsrel, _ = torch.min(rel_base_pts_to_rhand_joints.view(-1, 3), dim=0) # # print(f"maxx_pert_basejtsrel: {maxx_pert_basejtsrel}, minn_pert_basejtsrel: {minn_pert_basejtsrel}, maxx_basejtsrel: {maxx_basejtsrel}, minn_basejtsrel: {minn_basejtsrel}") ### Add noise to avg joints sequence # noise_avg_joints_sequence = th.randn_like(avg_joints_sequence) pert_avg_joints_sequence = self.q_sample( # bsz x nnjts x 3 --> the avg joints for each ... avg_joints_sequence, t, noise_avg_joints_sequence ) ### perturb offset-joints ### # joints_offset_sequence noise_joints_offset_sequence = th.randn_like(joints_offset_sequence) pert_joints_offset_sequence = self.q_sample( joints_offset_sequence, t, noise_joints_offset_sequence ) ### perturb offset-joints ### # rel base pts to pert_rel_base_pts_to_joints_for_jts_pred = pert_joints_offset_sequence.unsqueeze(-2) - normed_base_pts.unsqueeze(1).unsqueeze(1) # if self.args.use_jts_pert_realbasejtsrel: # pert_rel_base_pts_to_rhand_joints = pert_rel_base_pts_to_joints_for_jts_pred # noise_rel_base_pts_to_rhand_joints = noise_joints_offset_sequence.unsqueeze(-2).repeat(1, 1, 1, normed_base_pts.size(1), 1).contiguous() # if self.args.finetune_with_cond: # pert_rel_base_pts_to_rhand_joints = pert_joints_offset_sequence.unsqueeze(-2) - normed_base_pts.unsqueeze(1).unsqueeze(1) # x_start['per_frame_avg_disp_along_normals'] = torch.zeros_like(x_start['per_frame_avg_disp_along_normals']) # x_start['per_frame_avg_disp_vt_normals'] = torch.zeros_like(x_start['per_frame_avg_disp_vt_normals']) # x_start['per_frame_std_disp_along_normals'] = torch.ones_like(x_start['per_frame_std_disp_along_normals']) # bigpicture # # x_start['per_frame_std_disp_vt_normals'] = torch.ones_like(x_start['per_frame_std_disp_vt_normals']) # ##### E ###### # if not self.args.wo_e_normalization and self.args.e_normalization_stra == "cent": # # e normalization --> # bsz = e_disp_rel_to_base_along_normals.size(0) # nf, nnj, nnb = e_disp_rel_to_base_along_normals.size()[1:] # high dimensional # # the max value and min value of all values # # bsz x nnf x nnj x nnb --> for the along normals values and vt normals values ## # maxx_e_disp_rel_to_base_along_normals, _ = torch.max(e_disp_rel_to_base_along_normals.view(bsz, -1), dim=-1) # minn_e_disp_rel_to_base_along_normals, _ = torch.min(e_disp_rel_to_base_along_normals.view(bsz, -1), dim=-1) # maxx_e_disp_rel_to_base_vt_normals, _ = torch.max(e_disp_rel_to_baes_vt_normals.view(bsz, -1), dim=-1) # minn_e_disp_rel_to_base_vt_normals, _ = torch.min(e_disp_rel_to_baes_vt_normals.view(bsz, -1), dim=-1) # maxx_e_disp_rel_to_base_along_normals = maxx_e_disp_rel_to_base_along_normals.unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, 1, nnj, nnb) # minn_e_disp_rel_to_base_along_normals = minn_e_disp_rel_to_base_along_normals.unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, 1, nnj, nnb) # maxx_e_disp_rel_to_base_vt_normals = maxx_e_disp_rel_to_base_vt_normals.unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, 1, nnj, nnb) # minn_e_disp_rel_to_base_vt_normals = minn_e_disp_rel_to_base_vt_normals.unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, 1, nnj, nnb) # x_start['per_frame_avg_disp_along_normals'] = 
(maxx_e_disp_rel_to_base_along_normals + minn_e_disp_rel_to_base_along_normals) / 2. # x_start['per_frame_avg_disp_vt_normals'] = (maxx_e_disp_rel_to_base_vt_normals + minn_e_disp_rel_to_base_vt_normals) / 2. # x_start['per_frame_std_disp_along_normals'] = torch.ones_like(x_start['per_frame_std_disp_along_normals']) # x_start['per_frame_std_disp_vt_normals'] = torch.ones_like(x_start['per_frame_std_disp_vt_normals'] ) # normalize ## # e_disp_rel_to_base_along_normals = (e_disp_rel_to_base_along_normals - x_start['per_frame_avg_disp_along_normals']) / x_start['per_frame_std_disp_along_normals'] # e_disp_rel_to_baes_vt_normals = (e_disp_rel_to_baes_vt_normals - x_start['per_frame_avg_disp_vt_normals']) / x_start['per_frame_std_disp_vt_normals'] ##### E ###### ##### E ###### # # noise_e_disp_rel_to_base_vt_normals # # noise_e_disp_rel_to_base_along_normals, noise_e_disp_rel_to_base_vt_normals 3 # # per_frame_avg_disp_along_normals, per_frame_std_disp_along_normals # wo e normalization; # # per_frame_avg_disp_vt_normals, per_frame_std_disp_vt_normals # wo_e_normalization # # noise_e_disp_rel_to_base_along_normals = torch.zeros_like(e_disp_rel_to_base_along_normals) # pert_e_disp_rel_to_base_along_normals = self.q_sample( # e_disp_rel_to_base_along_normals, t, noise_e_disp_rel_to_base_along_normals # ) # # noise_e_disp_rel_to_base_vt_normals = torch.zeros_like(e_disp_rel_to_baes_vt_normals) # pert_e_disp_rel_to_base_vt_normals = self.q_sample( # e_disp_rel_to_baes_vt_normals, t, noise_e_disp_rel_to_base_vt_normals # ) ##### E ###### input_data = { 'base_pts': base_pts.clone(), # base pts ### 'base_normals': base_normals.clone(), # base normals ### 'rel_base_pts_to_rhand_joints': rel_base_pts_to_rhand_joints.clone(), 'rhand_joints': rhand_joints, 'avg_joints_sequence': avg_joints_sequence, ## bsz x nnjoints x 3 here for the avg_joints ## 'pert_avg_joints_sequence': pert_avg_joints_sequence, # 'pert_rel_base_pts_to_rhand_joints': pert_rel_base_pts_to_rhand_joints, ## pert realtive base pts to rhand joints ## 'pert_joints_offset_sequence': pert_joints_offset_sequence, 'normed_base_pts': normed_base_pts, 'pert_rel_base_pts_to_joints_for_jts_pred': pert_rel_base_pts_to_joints_for_jts_pred, 'pert_rhand_params': pert_rhand_params, } # input_data.update( # {k: x_start[k].clone() for k in x_start if k not in input_data} # ) # gaussian diffusion ours ## # rel_base_pts_to_rhand_joints in the input_data # if model_kwargs is None: model_kwargs = {} terms = {} # latents in the latent space # # sequence latents # # if self.args.train_diff: # with torch.no_grad(): # out_dict = model(input_data, self._scale_timesteps(t).clone()) # else: # clean_joint_seq_latents: seq x bs x d # # clean_joint_seq_latents = model(input_data, self._scale_timesteps(t).clone()) ## ### the strategy of removing noise from corresponding quantities ### out_dict = model(input_data, self._scale_timesteps(t).clone()) ### get model output dictionary ### KL_loss = 0. terms['rot_mse'] = 0. 
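# NOTE: each diffusion branch below (diff_jts, diff_realbasejtsrel_to_joints,
# diff_realbasejtsrel, diff_basejtsrel, diff_basejtse) instantiates the same
# regression pattern; a minimal illustrative sketch, where `clean`, `noise_`, and
# `model_out` are placeholder names, not variables defined in this scope:
#   pert = self.q_sample(clean, t, noise_)                    # forward diffusion q(x_t | x_0)
#   model_out = out_dict[...]                                 # denoiser output for this quantity
#   target = noise_ if self.args.pred_diff_noise else clean   # eps- vs. x0-prediction
#   terms['rot_mse'] += torch.sum((model_out - target) ** 2, dim=-1).mean(dim=-1).mean(dim=-1)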
### diff_jts ### # out dict of the # # resume checkpoints # dec_in_dict dec_in_dict = {} if self.diff_jts: ### Sample for perturbed joints seq latents ### clean_joint_seq_latents = out_dict["joint_seq_output"] noise_joint_seq_latents = th.randn_like(clean_joint_seq_latents) if self.args.const_noise: noise_joint_seq_latents = noise_joint_seq_latents[0].unsqueeze(0).repeat(noise_joint_seq_latents.size(0), 1, 1).contiguous() pert_joint_seq_latents = self.q_sample(clean_joint_seq_latents.permute(1, 0, 2).contiguous(), t, noise=noise_joint_seq_latents.permute(1, 0, 2).contiguous()).permute(1, 0, 2).contiguous() ### Sample for perturbed joints seq latents ### dec_in_dict['joints_seq_latents'] = pert_joint_seq_latents dec_in_dict['joints_seq_latents_enc'] = clean_joint_seq_latents if self.args.kl_weights > 0. and "joint_seq_output_mean" in out_dict and not self.args.train_diff: # clean_joint_seq_latents: seq_len x bs x d # log_p_joints_seq = model_util.standard_normal_logprob(clean_joint_seq_latents) log_p_joints_seq = log_p_joints_seq.permute(1, 0, 2).contiguous() # log_p_joints_seq = log_p_joints_seq.sum(dim=-1).mean(dim=-1) # log_p_joints_seq entropy_joints_seq = model_util.gaussian_entropy(out_dict['joint_seq_output_logvar'].permute(1, 2, 0)).mean(dim=-1) loss_prior_joints_seq = (- log_p_joints_seq - entropy_joints_seq) KL_loss += loss_prior_joints_seq # the dimension of latents ? ## if self.args.diff_realbasejtsrel_to_joints: joints_offset_output = out_dict["joints_offset_output"] if self.args.pred_diff_noise: jts_pred_loss = torch.sum((joints_offset_output - noise_joints_offset_sequence) ** 2, dim=-1).mean(dim=-1).mean(dim=-1) # bsz x ws x nnjts x 3 else: jts_pred_loss = torch.sum((joints_offset_output - joints_offset_sequence) ** 2, dim=-1).mean(dim=-1).mean(dim=-1) # terms['jts_pred_loss'] = jts_pred_loss terms['rot_mse'] += jts_pred_loss if self.diff_realbasejtsrel: real_dec_basejtsrel = out_dict['real_dec_basejtsrel'] # bsz x nf x nnj x nnb x 3 # noise_rel_base_pts_to_rhand_joints, rel_base_pts_to_rhand_joints if self.args.pred_diff_noise and not self.args.train_enc: # print(f"here predicting diff_noise...") if self.args.use_jts_pert_realbasejtsrel: # print(f"use_jts_pert_realbasejtsrel!!!") jts_pred_loss = torch.sum(( real_dec_basejtsrel[:, :, :, 0:1, :] - noise_rel_base_pts_to_rhand_joints[:, :, :, 0:1, :] ) ** 2, dim=-1 ).mean(dim=-1).mean(dim=-1).mean(dim=-1) # jts_pred_loss = torch.sum(( # real_dec_basejtsrel - noise_joints_offset_sequence # ) ** 2, dim=-1 ).mean(dim=-1).mean(dim=-1) else: jts_pred_loss = torch.sum(( real_dec_basejtsrel - noise_rel_base_pts_to_rhand_joints ) ** 2, dim=-1).mean(dim=-1).mean(dim=-1).mean(dim=-1) else: jts_pred_loss = torch.sum(( real_dec_basejtsrel - rel_base_pts_to_rhand_joints ) ** 2, dim=-1).mean(dim=-1).mean(dim=-1).mean(dim=-1) terms['jts_pred_loss'] = jts_pred_loss terms['rot_mse'] += jts_pred_loss if self.args.train_enc: obj_base_pts_feats = out_dict['obj_base_pts_feats'].detach() noise_obj_base_pts_feats = th.randn_like(obj_base_pts_feats) ## bsz x ws x nnjts x 3 -> for each joint point # pert_obj_base_pts_feats = self.q_sample( ## bsz x ws x nnjts x 3 --> pert in the spatial space ### obj_base_pts_feats.permute(1, 0, 2), t, noise_obj_base_pts_feats.permute(1, 0, 2) ).permute(1, 0, 2) # seq_len x bsz x nn_feats_dim # dec_obj_base_pts_feats = model.model.denoising_realbasejtsrel_objbasefeats( pert_obj_base_pts_feats, self._scale_timesteps(t).clone()) if self.args.pred_diff_noise: obj_base_pts_feats_denoising_loss = torch.sum(
(dec_obj_base_pts_feats - noise_obj_base_pts_feats) ** 2, dim=-1 ) / noise_obj_base_pts_feats.size(-1) obj_base_pts_feats_denoising_loss = obj_base_pts_feats_denoising_loss.transpose(0, 1).mean(dim=-1) else: obj_base_pts_feats_denoising_loss = torch.sum( (dec_obj_base_pts_feats - obj_base_pts_feats) ** 2, dim=-1 ) / obj_base_pts_feats.size(-1) obj_base_pts_feats_denoising_loss = obj_base_pts_feats_denoising_loss.transpose(0, 1).mean(dim=-1) terms['jts_latent_denoising_loss'] = obj_base_pts_feats_denoising_loss terms['rot_mse'] += obj_base_pts_feats_denoising_loss if self.diff_basejtsrel: # if 'basejtsrel_output' in out_dict: # basejtsrel_output = out_dict['basejtsrel_output'].transpose(-2, -3).contiguous() # avg_jts_outputs = out_dict['avg_jts_outputs'] # # print(f"basejtsrel_output: {basejtsrel_output.size()}, noise_rel_base_pts_to_rhand_joints: {noise_rel_base_pts_to_rhand_joints.size()}, rel_base_pts_to_rhand_joints: {rel_base_pts_to_rhand_joints.size()}") # if self.args.pred_diff_noise: # basejtsrel_denoising_loss = torch.sum((basejtsrel_output - noise_rel_base_pts_to_rhand_joints) ** 2, dim=-1).mean(dim=-1).mean(dim=-1).mean(dim=-1) # avgjts_denoising_loss = torch.sum((avg_jts_outputs - noise_avg_joints_sequence) ** 2, dim=-1).mean(dim=-1) # else: # basejtsrel_denoising_loss = torch.sum((basejtsrel_output - rel_base_pts_to_rhand_joints) ** 2, dim=-1).mean(dim=-1).mean(dim=-1).mean(dim=-1) # avgjts_denoising_loss = torch.sum((avg_jts_outputs - avg_joints_sequence) ** 2, dim=-1).mean(dim=-1) # else: # joints output dict from the out_dict ### we should add a data-loadingf joints_offset_output = out_dict['params_output'] if self.args.pred_diff_noise: basejtsrel_denoising_loss = torch.sum((joints_offset_output - noise_rhand_params) ** 2, dim=-1).mean(dim=-1).mean(dim=-1) # bsz x ws x nnjts x 3 --> mean and mena over dim=-1 else: # # basejtsrel denoising losses ## basejtsrel_denoising_loss = torch.sum((joints_offset_output - joints_offset_sequence) ** 2, dim=-1).mean(dim=-1).mean(dim=-1) # # if 'avg_jts_outputs' in out_dict: # avg jts outputs ## # avg_jts_outputs = out_dict['avg_jts_outputs'] # if self.args.pred_diff_noise: # avgjts_denoising_loss = torch.sum((avg_jts_outputs - noise_avg_joints_sequence) ** 2, dim=-1).mean(dim=-1) # else: # avgjts_denoising_loss = torch.sum((avg_jts_outputs - avg_joints_sequence) ** 2, dim=-1).mean(dim=-1) # else: # avgjts_denoising_loss = torch.zeros_like(basejtsrel_denoising_loss) terms['basejtrel_denoising_loss'] = basejtsrel_denoising_loss # terms['avgjts_denoising_loss'] = avgjts_denoising_loss # # terms['rot_mse'] += basejtsrel_denoising_loss # + avgjts_denoising_loss # jts denoising ## if self.diff_basejtse: ### Sample for perturbed basejtsrel seq latents ### dec_e_along_normals = out_dict['dec_e_along_normals'] dec_e_vt_normals = out_dict['dec_e_vt_normals'] # pred_e_along_vt_loss, # dec_e_along_normals; dec_e_vt_normals; # # bszx xnnj x nnb x 1 # noise_e_disp_rel_to_base_along_normals, noise_e_disp_rel_to_base_vt_normals 3 if self.args.pred_diff_noise: ## predict # predict e along and vt nromals ## pred_e_along_normals_loss = ((dec_e_along_normals - noise_e_disp_rel_to_base_along_normals) ** 2).mean(dim=-1).mean(dim=-1).mean(dim=-1) pred_e_along_vt_loss = ((dec_e_vt_normals - noise_e_disp_rel_to_base_vt_normals) ** 2).mean(dim=-1).mean(dim=-1).mean(dim=-1) terms['basejtse_along_normals_pred_loss'] = pred_e_along_normals_loss terms['basejtse_vt_normals_pred_loss'] = pred_e_along_vt_loss terms['rot_mse'] += pred_e_along_normals_loss + 
pred_e_along_vt_loss # sv_out_in = { # 'model_output': model_output.detach().cpu().numpy(), # 'target': target.detach().cpu().numpy(), # 't': t.detach().cpu().numpy(), # } import os import datetime cur_time_stamp = datetime.datetime.now().timestamp() cur_time_stamp = str(cur_time_stamp) # sv_dir_rt = "/data1/sim/motion-diffusion-model" # sv_out_fn = os.path.join(sv_dir_rt, f"out_{cur_time_stamp}.npy") # np.save(sv_out_fn,sv_inter_dict ) # print(f"Samples saved to {sv_out_fn}") target_xyz, model_output_xyz = None, None if self.lambda_rcxyz > 0. and dataset.dataname not in ['motion_ours']: print(f"Calculating lambda_rcxyz!!!") target_xyz = get_xyz(target) # [bs, nvertices(vertices)/njoints(smpl), 3, nframes] model_output_xyz = get_xyz(model_output) # [bs, nvertices, 3, nframes] terms["rcxyz_mse"] = self.masked_l2(target_xyz, model_output_xyz, mask) # mean_flat((target_xyz - model_output_xyz) ** 2) if self.lambda_vel_rcxyz > 0.: if self.data_rep == 'rot6d' and dataset.dataname in ['humanact12', 'uestc']: target_xyz = get_xyz(target) if target_xyz is None else target_xyz model_output_xyz = get_xyz(model_output) if model_output_xyz is None else model_output_xyz target_xyz_vel = (target_xyz[:, :, :, 1:] - target_xyz[:, :, :, :-1]) model_output_xyz_vel = (model_output_xyz[:, :, :, 1:] - model_output_xyz[:, :, :, :-1]) terms["vel_xyz_mse"] = self.masked_l2(target_xyz_vel, model_output_xyz_vel, mask[:, :, :, 1:]) if self.lambda_fc > 0.: ## lambda fc ## torch.autograd.set_detect_anomaly(True) if self.data_rep == 'rot6d' and dataset.dataname in ['humanact12', 'uestc']: target_xyz = get_xyz(target) if target_xyz is None else target_xyz model_output_xyz = get_xyz(model_output) if model_output_xyz is None else model_output_xyz # 'L_Ankle', # 7, 'R_Ankle', # 8 , 'L_Foot', # 10, 'R_Foot', # 11 l_ankle_idx, r_ankle_idx, l_foot_idx, r_foot_idx = 7, 8, 10, 11 relevant_joints = [l_ankle_idx, l_foot_idx, r_ankle_idx, r_foot_idx] gt_joint_xyz = target_xyz[:, relevant_joints, :, :] # [BatchSize, 4, 3, Frames] gt_joint_vel = torch.linalg.norm(gt_joint_xyz[:, :, :, 1:] - gt_joint_xyz[:, :, :, :-1], axis=2) # [BatchSize, 4, Frames] fc_mask = torch.unsqueeze((gt_joint_vel <= 0.01), dim=2).repeat(1, 1, 3, 1) pred_joint_xyz = model_output_xyz[:, relevant_joints, :, :] # [BatchSize, 4, 3, Frames] pred_vel = pred_joint_xyz[:, :, :, 1:] - pred_joint_xyz[:, :, :, :-1] pred_vel[~fc_mask] = 0 terms["fc"] = self.masked_l2(pred_vel, torch.zeros(pred_vel.shape, device=pred_vel.device), mask[:, :, :, 1:]) self.lambda_vel = 0. if self.lambda_vel > 0.: target_vel = (target[..., 1:] - target[..., :-1]) model_output_vel = (model_output[..., 1:] - model_output[..., :-1]) terms["vel_mse"] = self.masked_l2(target_vel[:, :-1, :, :], # Remove last joint, is the root location! model_output_vel[:, :-1, :, :], mask[:, :, :, 1:]) # mean_flat((target_vel - model_output_vel) ** 2) terms["loss"] = terms["rot_mse"] + terms.get('vb', 0.) +\ (self.lambda_vel * terms.get('vel_mse', 0.)) +\ (self.lambda_rcxyz * terms.get('rcxyz_mse', 0.)) + \ (self.lambda_fc * terms.get('fc', 0.)) # else: # raise NotImplementedError(self.loss_type) return terms ## training losses def predict_sample_single_step(self, model, x_start, t, model_kwargs=None, noise=None, dataset=None): """ # s ## predict sa Compute training losses for a single timestep. :param model: the model to evaluate loss on. :param x_start: the [N x C x ...] tensor of inputs. :param t: a batch of timestep indices. 
:param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :param noise: if specified, the specific Gaussian noise to try to remove. :return: a dict with the key "loss" containing a tensor of shape [N]. Some mean or variance settings may also have other keys. """ # enc = model.model._modules['module'] # enc = model.model mask = model_kwargs['y']['mask'] ## rot2xyz; # # ### avg_joints, std_joints ### # if 'avg_joints' in model_kwargs['y']: avg_joints = model_kwargs['y']['avg_joints'].unsqueeze(-1) std_joints = model_kwargs['y']['std_joints'].unsqueeze(-1).unsqueeze(-1) else: avg_joints = None std_joints = None # ### avg_joints, std_joints ### # # get_xyz = lambda sample: enc.rot2xyz(sample, mask=None, pose_rep=enc.pose_rep, translation=enc.translation, # glob=enc.glob, ## rot2xyz; ## # # jointstype='vertices', # 3.4 iter/sec # USED ALSO IN MotionCLIP # jointstype='smpl', # 3.4 iter/sec # vertstrans=False) if model_kwargs is None: ## predict single steps --> model_kwargs = {} if noise is None: noise = th.randn_like(x_start) x_t = self.q_sample(x_start, t, noise=noise) # randn_like for x_start, t, x_t --> get x_t from x_start # # how we control the tiem stamp t? terms = {} if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: terms["loss"] = self._vb_terms_bpd( ## vb terms bpd # model=model, x_start=x_start, x_t=x_t, t=t, clip_denoised=False, model_kwargs=model_kwargs, )["output"] if self.loss_type == LossType.RESCALED_KL: terms["loss"] *= self.num_timesteps elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: model_output = model(x_t, self._scale_timesteps(t), **model_kwargs) # model_output ---> model x_t # if self.model_var_type in [ ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE, ]: # s B, C = x_t.shape[:2] assert model_output.shape == (B, C * 2, *x_t.shape[2:]) model_output, model_var_values = th.split(model_output, C, dim=1) # Learn the variance using the variational bound, but don't let # it affect our mean prediction. frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) terms["vb"] = self._vb_terms_bpd( model=lambda *args, r=frozen_out: r, x_start=x_start, x_t=x_t, t=t, clip_denoised=False, )["output"] if self.loss_type == LossType.RESCALED_MSE: # Divide by 1000 for equivalence with initial implementation. # Without a factor of 1/1000, the VB term hurts the MSE term. 
terms["vb"] *= self.num_timesteps / 1000.0 target = { # q posterior mean variance # ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( x_start=x_start, x_t=x_t, t=t )[0], ModelMeanType.START_X: x_start, ModelMeanType.EPSILON: noise, }[self.model_mean_type] # model mean type --> mean type # assert model_output.shape == target.shape == x_start.shape # [bs, njoints, nfeats, nframes] terms["rot_mse"] = self.masked_l2(target, model_output, mask) # mean_flat(rot_mse) # model_output, target, t # ### avg_joints, std_joints ### # if avg_joints is not None: print(f"model_output: {model_output.size()}, target: {target.size()}, std_joints: {std_joints.size()}, avg_joints: {avg_joints.size()}") print(f"Denormalizing joints...") model_output = (model_output * std_joints) + avg_joints target = (target * std_joints) + avg_joints sv_out_in = { # 'model_output': model_output.detach().cpu().numpy(), 'target': target.detach().cpu().numpy(), 't': t.detach().cpu().numpy(), } import os import datetime cur_time_stamp = datetime.datetime.now().timestamp() cur_time_stamp = str(cur_time_stamp) sv_dir_rt = "/data1/sim/motion-diffusion-model" sv_out_fn = os.path.join(sv_dir_rt, f"out_{cur_time_stamp}.npy") np.save(sv_out_fn,sv_out_in ) print(f"Samples saved to {sv_out_fn}") target_xyz, model_output_xyz = None, None self.lambda_rcxyz = 0. if self.lambda_rcxyz > 0.: target_xyz = get_xyz(target) # [bs, nvertices(vertices)/njoints(smpl), 3, nframes] model_output_xyz = get_xyz(model_output) # [bs, nvertices, 3, nframes] terms["rcxyz_mse"] = self.masked_l2(target_xyz, model_output_xyz, mask) # mean_flat((target_xyz - model_output_xyz) ** 2) if self.lambda_vel_rcxyz > 0.: if self.data_rep == 'rot6d' and dataset.dataname in ['humanact12', 'uestc']: target_xyz = get_xyz(target) if target_xyz is None else target_xyz model_output_xyz = get_xyz(model_output) if model_output_xyz is None else model_output_xyz target_xyz_vel = (target_xyz[:, :, :, 1:] - target_xyz[:, :, :, :-1]) model_output_xyz_vel = (model_output_xyz[:, :, :, 1:] - model_output_xyz[:, :, :, :-1]) terms["vel_xyz_mse"] = self.masked_l2(target_xyz_vel, model_output_xyz_vel, mask[:, :, :, 1:]) if self.lambda_fc > 0.: ## lambda fc ## torch.autograd.set_detect_anomaly(True) if self.data_rep == 'rot6d' and dataset.dataname in ['humanact12', 'uestc']: target_xyz = get_xyz(target) if target_xyz is None else target_xyz model_output_xyz = get_xyz(model_output) if model_output_xyz is None else model_output_xyz # 'L_Ankle', # 7, 'R_Ankle', # 8 , 'L_Foot', # 10, 'R_Foot', # 11 l_ankle_idx, r_ankle_idx, l_foot_idx, r_foot_idx = 7, 8, 10, 11 relevant_joints = [l_ankle_idx, l_foot_idx, r_ankle_idx, r_foot_idx] gt_joint_xyz = target_xyz[:, relevant_joints, :, :] # [BatchSize, 4, 3, Frames] gt_joint_vel = torch.linalg.norm(gt_joint_xyz[:, :, :, 1:] - gt_joint_xyz[:, :, :, :-1], axis=2) # [BatchSize, 4, Frames] fc_mask = torch.unsqueeze((gt_joint_vel <= 0.01), dim=2).repeat(1, 1, 3, 1) pred_joint_xyz = model_output_xyz[:, relevant_joints, :, :] # [BatchSize, 4, 3, Frames] pred_vel = pred_joint_xyz[:, :, :, 1:] - pred_joint_xyz[:, :, :, :-1] pred_vel[~fc_mask] = 0 terms["fc"] = self.masked_l2(pred_vel, torch.zeros(pred_vel.shape, device=pred_vel.device), mask[:, :, :, 1:]) if self.lambda_vel > 0.: target_vel = (target[..., 1:] - target[..., :-1]) model_output_vel = (model_output[..., 1:] - model_output[..., :-1]) terms["vel_mse"] = self.masked_l2(target_vel[:, :-1, :, :], # Remove last joint, is the root location! 
model_output_vel[:, :-1, :, :], mask[:, :, :, 1:]) # mean_flat((target_vel - model_output_vel) ** 2) terms["loss"] = terms["rot_mse"] + terms.get('vb', 0.) +\ (self.lambda_vel * terms.get('vel_mse', 0.)) +\ (self.lambda_rcxyz * terms.get('rcxyz_mse', 0.)) + \ (self.lambda_fc * terms.get('fc', 0.)) else: raise NotImplementedError(self.loss_type) return terms, model_output, target, t def fc_loss_rot_repr(self, gt_xyz, pred_xyz, mask): def to_np_cpu(x): return x.detach().cpu().numpy() """ pose_xyz: SMPL batch tensor of shape: [BatchSize, 24, 3, Frames] """ # 'L_Ankle', # 7, 'R_Ankle', # 8 , 'L_Foot', # 10, 'R_Foot', # 11 l_ankle_idx, r_ankle_idx = 7, 8 l_foot_idx, r_foot_idx = 10, 11 """ Contact calculated by 'Kfir Method' Commented code)""" # contact_signal = torch.zeros((pose_xyz.shape[0], pose_xyz.shape[3], 2), device=pose_xyz.device) # [BatchSize, Frames, 2] # left_xyz = 0.5 * (pose_xyz[:, l_ankle_idx, :, :] + pose_xyz[:, l_foot_idx, :, :]) # [BatchSize, 3, Frames] # right_xyz = 0.5 * (pose_xyz[:, r_ankle_idx, :, :] + pose_xyz[:, r_foot_idx, :, :]) # left_z, right_z = left_xyz[:, 2, :], right_xyz[:, 2, :] # [BatchSize, Frames] # left_velocity = torch.linalg.norm(left_xyz[:, :, 2:] - left_xyz[:, :, :-2], axis=1) # [BatchSize, Frames] # right_velocity = torch.linalg.norm(left_xyz[:, :, 2:] - left_xyz[:, :, :-2], axis=1) # # left_z_mask = left_z <= torch.mean(torch.sort(left_z)[0][:, :left_z.shape[1] // 5], axis=-1) # left_z_mask = torch.stack([left_z_mask, left_z_mask], dim=-1) # [BatchSize, Frames, 2] # left_z_mask[:, :, 1] = False # Blank right side # contact_signal[left_z_mask] = 0.4 # # right_z_mask = right_z <= torch.mean(torch.sort(right_z)[0][:, :right_z.shape[1] // 5], axis=-1) # right_z_mask = torch.stack([right_z_mask, right_z_mask], dim=-1) # [BatchSize, Frames, 2] # right_z_mask[:, :, 0] = False # Blank left side # contact_signal[right_z_mask] = 0.4 # contact_signal[left_z <= (torch.mean(torch.sort(left_z)[:left_z.shape[0] // 5]) + 20), 0] = 1 # contact_signal[right_z <= (torch.mean(torch.sort(right_z)[:right_z.shape[0] // 5]) + 20), 1] = 1 # plt.plot(to_np_cpu(left_z[0]), label='left_z') # plt.plot(to_np_cpu(left_velocity[0]), label='left_velocity') # plt.plot(to_np_cpu(contact_signal[0, :, 0]), label='left_fc') # plt.grid() # plt.legend() # plt.show() # plt.plot(to_np_cpu(right_z[0]), label='right_z') # plt.plot(to_np_cpu(right_velocity[0]), label='right_velocity') # plt.plot(to_np_cpu(contact_signal[0, :, 1]), label='right_fc') # plt.grid() # plt.legend() # plt.show() gt_joint_xyz = gt_xyz[:, [l_ankle_idx, l_foot_idx, r_ankle_idx, r_foot_idx], :, :] # [BatchSize, 4, 3, Frames] gt_joint_vel = torch.linalg.norm(gt_joint_xyz[:, :, :, 1:] - gt_joint_xyz[:, :, :, :-1], axis=2) # [BatchSize, 4, Frames] fc_mask = (gt_joint_vel <= 0.01) pred_joint_xyz = pred_xyz[:, [l_ankle_idx, l_foot_idx, r_ankle_idx, r_foot_idx], :, :] # [BatchSize, 4, 3, Frames] pred_joint_vel = torch.linalg.norm(pred_joint_xyz[:, :, :, 1:] - pred_joint_xyz[:, :, :, :-1], axis=2) # [BatchSize, 4, Frames] pred_joint_vel[~fc_mask] = 0 # Blank non-contact velocities frames. 
[BS,4,FRAMES] pred_joint_vel = torch.unsqueeze(pred_joint_vel, dim=2) """DEBUG CODE""" # print(f'mask: {mask.shape}') # print(f'pred_joint_vel: {pred_joint_vel.shape}') # plt.title(f'Joint: {joint_idx}') # plt.plot(to_np_cpu(gt_joint_vel[0]), label='velocity') # plt.plot(to_np_cpu(fc_mask[0]), label='fc') # plt.grid() # plt.legend() # plt.show() return self.masked_l2(pred_joint_vel, torch.zeros(pred_joint_vel.shape, device=pred_joint_vel.device), mask[:, :, :, 1:]) # TODO - NOT USED YET, JUST COMMITING TO NOT DELETE THIS AND KEEP INITIAL IMPLEMENTATION, NOT DONE! def foot_contact_loss_humanml3d(self, target, model_output): # root_rot_velocity (B, seq_len, 1) # root_linear_velocity (B, seq_len, 2) # root_y (B, seq_len, 1) # ric_data (B, seq_len, (joint_num - 1)*3) , XYZ # rot_data (B, seq_len, (joint_num - 1)*6) , 6D # local_velocity (B, seq_len, joint_num*3) , XYZ # foot contact (B, seq_len, 4) , target_fc = target[:, -4:, :, :] root_rot_velocity = target[:, :1, :, :] root_linear_velocity = target[:, 1:3, :, :] root_y = target[:, 3:4, :, :] ric_data = target[:, 4:67, :, :] # 4+(3*21)=67 rot_data = target[:, 67:193, :, :] # 67+(6*21)=193 local_velocity = target[:, 193:259, :, :] # 193+(3*22)=259 contact = target[:, 259:, :, :] # 193+(3*22)=259 contact_mask_gt = contact > 0.5 # contact mask order for indexes are fid_l [7, 10], fid_r [8, 11] vel_lf_7 = local_velocity[:, 7 * 3:8 * 3, :, :] vel_rf_8 = local_velocity[:, 8 * 3:9 * 3, :, :] vel_lf_10 = local_velocity[:, 10 * 3:11 * 3, :, :] vel_rf_11 = local_velocity[:, 11 * 3:12 * 3, :, :] calc_vel_lf_7 = ric_data[:, 6 * 3:7 * 3, :, 1:] - ric_data[:, 6 * 3:7 * 3, :, :-1] calc_vel_rf_8 = ric_data[:, 7 * 3:8 * 3, :, 1:] - ric_data[:, 7 * 3:8 * 3, :, :-1] calc_vel_lf_10 = ric_data[:, 9 * 3:10 * 3, :, 1:] - ric_data[:, 9 * 3:10 * 3, :, :-1] calc_vel_rf_11 = ric_data[:, 10 * 3:11 * 3, :, 1:] - ric_data[:, 10 * 3:11 * 3, :, :-1] # vel_foots = torch.stack([vel_lf_7, vel_lf_10, vel_rf_8, vel_rf_11], dim=1) for chosen_vel_foot_calc, chosen_vel_foot, joint_idx, contact_mask_idx in zip( [calc_vel_lf_7, calc_vel_rf_8, calc_vel_lf_10, calc_vel_rf_11], [vel_lf_7, vel_lf_10, vel_rf_8, vel_rf_11], [7, 10, 8, 11], [0, 1, 2, 3]): tmp_mask_gt = contact_mask_gt[:, contact_mask_idx, :, :].cpu().detach().numpy().reshape(-1).astype(int) chosen_vel_norm = np.linalg.norm(chosen_vel_foot.cpu().detach().numpy().reshape((3, -1)), axis=0) chosen_vel_calc_norm = np.linalg.norm(chosen_vel_foot_calc.cpu().detach().numpy().reshape((3, -1)), axis=0) print(tmp_mask_gt.shape) print(chosen_vel_foot.shape) print(chosen_vel_calc_norm.shape) import matplotlib.pyplot as plt plt.plot(tmp_mask_gt, label='FC mask') plt.plot(chosen_vel_norm, label='Vel. XYZ norm (from vector)') plt.plot(chosen_vel_calc_norm, label='Vel. XYZ norm (calculated diff XYZ)') plt.title(f'FC idx {contact_mask_idx}, Joint Index {joint_idx}') plt.legend() plt.show() # print(vel_foots.shape) return 0 # TODO - NOT USED YET, JUST COMMITING TO NOT DELETE THIS AND KEEP INITIAL IMPLEMENTATION, NOT DONE! 
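# NOTE: offset arithmetic behind the HumanML3D slicing comments used in
# foot_contact_loss_humanml3d above and velocity_consistency_loss_humanml3d below
# (joint_num = 22, 263-dim features; a sanity check, not new behavior):
#   root_rot_velocity    [:,   0:1  ]  1
#   root_linear_velocity [:,   1:3  ]  2
#   root_y               [:,   3:4  ]  1
#   ric_data             [:,   4:67 ]  3 * (22 - 1) = 63,  4 + 63   = 67
#   rot_data             [:,  67:193]  6 * (22 - 1) = 126, 67 + 126 = 193
#   local_velocity       [:, 193:259]  3 * 22       = 66,  193 + 66 = 259
#   foot contact         [:, 259:263]  4                -> 263 total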
def velocity_consistency_loss_humanml3d(self, target, model_output): # root_rot_velocity (B, seq_len, 1) # root_linear_velocity (B, seq_len, 2) # root_y (B, seq_len, 1) # ric_data (B, seq_len, (joint_num - 1)*3) , XYZ # rot_data (B, seq_len, (joint_num - 1)*6) , 6D # local_velocity (B, seq_len, joint_num*3) , XYZ # foot contact (B, seq_len, 4) , target_fc = target[:, -4:, :, :] root_rot_velocity = target[:, :1, :, :] root_linear_velocity = target[:, 1:3, :, :] root_y = target[:, 3:4, :, :] ric_data = target[:, 4:67, :, :] # 4+(3*21)=67 rot_data = target[:, 67:193, :, :] # 67+(6*21)=193 local_velocity = target[:, 193:259, :, :] # 193+(3*22)=259 contact = target[:, 259:, :, :] # 193+(3*22)=259 calc_vel_from_xyz = ric_data[:, :, :, 1:] - ric_data[:, :, :, :-1] velocity_from_vector = local_velocity[:, 3:, :, 1:] # Slicing out root r_rot_quat, r_pos = motion_process.recover_root_rot_pos(target.permute(0, 2, 3, 1).type(th.FloatTensor)) print(f'r_rot_quat: {r_rot_quat.shape}') print(f'calc_vel_from_xyz: {calc_vel_from_xyz.shape}') calc_vel_from_xyz = calc_vel_from_xyz.permute(0, 2, 3, 1) calc_vel_from_xyz = calc_vel_from_xyz.reshape((1, 1, -1, 21, 3)).type(th.FloatTensor) r_rot_quat_adapted = r_rot_quat[..., :-1, None, :].repeat((1,1,1,21,1)).to(calc_vel_from_xyz.device) print(f'calc_vel_from_xyz: {calc_vel_from_xyz.shape} , {calc_vel_from_xyz.device}') print(f'r_rot_quat_adapted: {r_rot_quat_adapted.shape}, {r_rot_quat_adapted.device}') calc_vel_from_xyz = motion_process.qrot(r_rot_quat_adapted, calc_vel_from_xyz) calc_vel_from_xyz = calc_vel_from_xyz.reshape((1, 1, -1, 21 * 3)) calc_vel_from_xyz = calc_vel_from_xyz.permute(0, 3, 1, 2) print(f'calc_vel_from_xyz: {calc_vel_from_xyz.shape} , {calc_vel_from_xyz.device}') import matplotlib.pyplot as plt for i in range(21): plt.plot(np.linalg.norm(calc_vel_from_xyz[:,i*3:(i+1)*3,:,:].cpu().detach().numpy().reshape((3, -1)), axis=0), label='Calc Vel') plt.plot(np.linalg.norm(velocity_from_vector[:,i*3:(i+1)*3,:,:].cpu().detach().numpy().reshape((3, -1)), axis=0), label='Vector Vel') plt.title(f'Joint idx: {i}') plt.legend() plt.show() print(calc_vel_from_xyz.shape) print(velocity_from_vector.shape) diff = calc_vel_from_xyz-velocity_from_vector print(np.linalg.norm(diff.cpu().detach().numpy().reshape((63, -1)), axis=0)) return 0 def _prior_bpd(self, x_start): """ Get the prior KL term for the variational lower-bound, measured in bits-per-dim. This term can't be optimized, as it only depends on the encoder. :param x_start: the [N x C x ...] tensor of inputs. :return: a batch of [N] KL values (in bits), one per batch element. """ batch_size = x_start.shape[0] t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) kl_prior = normal_kl( mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 ) return mean_flat(kl_prior) / np.log(2.0) def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): """ Compute the entire variational lower-bound, measured in bits-per-dim, as well as other related quantities. :param model: the model to evaluate loss on. :param x_start: the [N x C x ...] tensor of inputs. :param clip_denoised: if True, clip denoised samples. :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :return: a dict containing the following keys: - total_bpd: the total variational lower-bound, per batch element. - prior_bpd: the prior term in the lower-bound. 
- vb: an [N x T] tensor of terms in the lower-bound. - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. - mse: an [N x T] tensor of epsilon MSEs for each timestep. """ device = x_start.device batch_size = x_start.shape[0] vb = [] xstart_mse = [] mse = [] for t in list(range(self.num_timesteps))[::-1]: t_batch = th.tensor([t] * batch_size, device=device) noise = th.randn_like(x_start) x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) # Calculate VLB term at the current timestep with th.no_grad(): out = self._vb_terms_bpd( model, x_start=x_start, x_t=x_t, t=t_batch, clip_denoised=clip_denoised, model_kwargs=model_kwargs, ) vb.append(out["output"]) xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) mse.append(mean_flat((eps - noise) ** 2)) vb = th.stack(vb, dim=1) xstart_mse = th.stack(xstart_mse, dim=1) mse = th.stack(mse, dim=1) prior_bpd = self._prior_bpd(x_start) total_bpd = vb.sum(dim=1) + prior_bpd return { "total_bpd": total_bpd, "prior_bpd": prior_bpd, "vb": vb, "xstart_mse": xstart_mse, "mse": mse, } class GaussianDiffusionV7: """ Utilities for training and sampling diffusion models. Ported directly from here, and then adapted over time to further experimentation. https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 :param betas: a 1-D numpy array of betas for each diffusion timestep, ## pass floating point timesteps ## starting at T and going to 1. :param model_mean_type: a ModelMeanType determining what the model outputs. :param model_var_type: a ModelVarType determining how variance is output. :param loss_type: a LossType determining the loss function to use. :param rescale_timesteps: if True, pass floating point timesteps into the model so that they are always scaled like in the original paper (0 to 1000). """ def __init__( self, *, betas, model_mean_type, model_var_type, loss_type, rescale_timesteps=False, lambda_rcxyz=0., lambda_vel=0., lambda_pose=1., lambda_orient=1., lambda_loc=1., data_rep='rot6d', lambda_root_vel=0., lambda_vel_rcxyz=0., lambda_fc=0., denoising_stra="rep", inter_optim=False, args=None, ): self.model_mean_type = model_mean_type self.model_var_type = model_var_type ## model var type ## self.loss_type = loss_type self.rescale_timesteps = rescale_timesteps self.data_rep = data_rep self.args = args # possibly None ### GET the diff. suit ### self.diff_jts = self.args.diff_jts self.diff_basejtsrel = self.args.diff_basejtsrel self.diff_basejtse = self.args.diff_basejtse self.diff_realbasejtsrel = self.args.diff_realbasejtsrel ### GET the diff. suit ### if data_rep != 'rot_vel' and lambda_pose != 1.: raise ValueError('lambda_pose is relevant only when training on velocities!') self.lambda_pose = lambda_pose self.lambda_orient = lambda_orient self.lambda_loc = lambda_loc self.lambda_rcxyz = lambda_rcxyz self.lambda_vel = lambda_vel self.lambda_root_vel = lambda_root_vel self.lambda_vel_rcxyz = lambda_vel_rcxyz self.lambda_fc = lambda_fc ### === denoising_stra for the denoising process === ### self.denoising_stra = denoising_stra self.inter_optim = inter_optim if self.lambda_rcxyz > 0. or self.lambda_vel > 0. or self.lambda_root_vel > 0. or \ self.lambda_vel_rcxyz > 0. or self.lambda_fc > 0.: assert self.loss_type == LossType.MSE, 'Geometric losses are supported by MSE loss type only!' 
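# NOTE: the schedule bookkeeping set up below follows the standard DDPM closed forms
# (a reference sketch; the actual arrays are computed right after this comment):
#   q(x_t | x_0) = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I)
#   posterior variance: beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)
# Tiny numeric check with hypothetical betas = [0.1, 0.2]:
#   alphas = [0.9, 0.8]; alphas_cumprod = [0.9, 0.72]; alphas_cumprod_prev = [1.0, 0.9]
#   posterior_variance[1] = 0.2 * (1 - 0.9) / (1 - 0.72) ~= 0.0714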
self.var_sched = VarianceSchedule(len(betas), torch.tensor(betas, dtype=torch.float64)) # Use float64 for accuracy. betas = np.array(betas, dtype=np.float64) self.betas = betas assert len(betas.shape) == 1, "betas must be 1-D" ## betas assert (betas > 0).all() and (betas <= 1).all() self.num_timesteps = int(betas.shape[0]) alphas = 1.0 - betas self.alphas_cumprod = np.cumprod(alphas, axis=0) self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) ## alphas cumprod self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) # calculations for diffusion q(x_t | x_{t-1}) and others self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) # calculations for posterior q(x_{t-1} | x_t, x_0) self.posterior_variance = ( betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) ) # log calculation clipped because the posterior variance is 0 at the # beginning of the diffusion chain. self.posterior_log_variance_clipped = np.log( np.append(self.posterior_variance[1], self.posterior_variance[1:]) ) self.posterior_mean_coef1 = ( betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) ) self.posterior_mean_coef2 = ( (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) ) self.l2_loss = lambda a, b: (a - b) ** 2 # th.nn.MSELoss(reduction='none') # must be None for handling mask later on. def masked_l2(self, a, b, mask): # assuming a.shape == b.shape == bs, J, Jdim, seqlen # assuming mask.shape == bs, 1, 1, seqlen loss = self.l2_loss(a, b) loss = sum_flat(loss * mask.float()) # gives \sigma_euclidean over unmasked elements n_entries = a.shape[1] * a.shape[2] non_zero_elements = sum_flat(mask) * n_entries ## non-zero-elements # print('mask', mask.shape) # print('non_zero_elements', non_zero_elements) # print('loss', loss) mse_loss_val = loss / non_zero_elements # print('mse_loss_val', mse_loss_val) return mse_loss_val def q_mean_variance(self, x_start, t): """ Get the distribution q(x_t | x_0). # q-mean-variance # :param x_start: the [N x C x ...] tensor of noiseless inputs. :param t: the number of diffusion steps (minus 1). Here, 0 means one step. :return: A tuple (mean, variance, log_variance), all of x_start's shape. """ mean = ( _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start ) variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) log_variance = _extract_into_tensor( self.log_one_minus_alphas_cumprod, t, x_start.shape ) return mean, variance, log_variance def q_sample(self, x_start, t, noise=None): """ Diffuse the dataset for a given number of diffusion steps. In other words, sample from q(x_t | x_0). ## q pos :param x_start: the initial dataset batch. :param t: the number of diffusion steps (minus 1). Here, 0 means one step. :param noise: if specified, the split-out normal noise. :return: A noisy version of x_start. 
""" if noise is None: noise = th.randn_like(x_start) assert noise.shape == x_start.shape return ( _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise ### variance xxx noise ### ) ## q_sample, def q_posterior_mean_variance(self, x_start, x_t, t): ## q_posterior """ Compute the mean and variance of the diffusion posterior: q(x_{t-1} | x_t, x_0) """ assert x_start.shape == x_t.shape posterior_mean = ( # posterior mean and variance # _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t ) posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) posterior_log_variance_clipped = _extract_into_tensor( self.posterior_log_variance_clipped, t, x_t.shape ) assert ( posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == x_start.shape[0] ) return posterior_mean, posterior_variance, posterior_log_variance_clipped def p_mean_variance_cond( ## get mean data ## self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None ): """ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of the initial x, x_0. :param model: the model, which takes a signal and a batch of timesteps as input. :param x: the [N x C x ...] tensor at time t. :param t: a 1-D Tensor of timesteps. :param clip_denoised: if True, clip the denoised signal into [-1, 1]. # # clip the denoised signal ## :param denoised_fn: if not None, a function which applies to the x_start prediction before it is used to sample. Applies before clip_denoised. # denoised fn :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :return: a dict with the following keys: - 'mean': the model mean output. - 'variance': the model variance output. - 'log_variance': the log of 'variance'. - 'pred_xstart': the prediction for x_0. """ # p_mean_varaince # if model_kwargs is None: model_kwargs = {} B = x['base_pts'].shape[0] assert t.shape == (B,) # print(f"t_shape: {t.shape}", "base_pts", x['base_pts'].size()) input_data = x out_dict = model(input_data, self._scale_timesteps(t).clone()) rt_dict = {} real_basejtsrel_seq_rt_dict = {} basejtsrel_seq_rt_dict = {} realbasejtsrel_to_joints_rt_dict = {} model_variance, model_log_variance = self.posterior_variance, self.posterior_log_variance_clipped if self.diff_realbasejtsrel and self.diff_basejtsrel: # diff_realbasejtsrel, real_basejtsrel_seq_rt_dict pert_rel_base_pts_outputs = x['pert_joints_offset_sequence'] basejtsrel_output = out_dict['joints_offset_output'] score_jts = basejtsrel_output # print(f"basejtsrel_output: {basejtsrel_output.size()}") # if self.args.use_var_sched: # bsz = basejtsrel_output.size(0) # t_item = t[0].item() # alpha = self.var_sched.alphas[t_item] # alpha_bar = self.var_sched.alpha_bars[t_item] # sigma = self.var_sched.get_sigmas(t_item, 0.) 
# sigma # # c0 = 1.0 / torch.sqrt(alpha) # c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar) # beta = self.var_sched.betas[[t[0].item()] * bsz] # z = torch.randn_like(basejtsrel_output) if t_item > 0 else torch.zeros_like(basejtsrel_output) # basejtsrel_output = c0 * (pert_rel_base_pts_outputs - c1 * basejtsrel_output) + sigma * z # theta # else: # basejtsrel_output = self._predict_xstart_from_eps(pert_rel_base_pts_outputs, t=t, eps=basejtsrel_output) real_dec_basejtsrel = out_dict['real_dec_basejtsrel'] # bsz x nf x nnj x nnb x 3 pert_rel_base_pts_to_rhand_joints = x['pert_rel_base_pts_to_rhand_joints'] # combine those two things # if self.args.pred_diff_noise and not self.args.train_enc: if self.args.add_noise_onjts: # add noise onjts # if self.args.real_basejtsrel_norm_stra == "std": ### std denormed_rel_base_pts_to_rhand_joints = pert_rel_base_pts_to_rhand_joints * x['std_rel_base_pts_to_rhand_joints'] + x['avg_rel_base_pts_to_rhand_joints'] else: denormed_rel_base_pts_to_rhand_joints = pert_rel_base_pts_to_rhand_joints # jts_fr_basepts = denormed_rel_base_pts_to_rhand_joints + x['normed_base_pts'].unsqueeze(1).unsqueeze(1) # jts_fr_basepts = jts_fr_basepts.mean(dim=-2) jts_fr_basepts = pert_rel_base_pts_outputs # pert # score_jts_fr_basepts = real_dec_basejtsrel.mean(dim=-2) # bsz x seq x nnjts x 3 score_jts_fr_basepts = real_dec_basejtsrel[..., self.args.sel_basepts_idx, :] t_item = t[0].item() alpha_bar = self.var_sched.alpha_bars[t_item] # combined_socre = score_jts - torch.sqrt(1. - alpha_bar) * score_jts_fr_basepts # combined_socre = score_jts # - torch.sqrt(1. - alpha_bar) * socre_jts_fr_basepts # combined_socre = score_jts_fr_basepts if self.args.only_cmb_finger: combined_socre = score_jts # combined_socre[..., -5:, :] = combined_socre[..., -5:, :] - (torch.sqrt(1. - alpha_bar) * score_jts_fr_basepts)[..., -5:, :] # combined_socre[..., -5:, :] = combined_socre[..., -5:, :] * 0.7 + score_jts_fr_basepts[..., -5:, :] * 0.3 # combined_socre = combined_socre * 0.2 + score_jts_fr_basepts * 0.8 combined_socre = combined_socre * 0.1 + score_jts_fr_basepts * 0.9 # combined_socre = combined_socre * 0.05 + score_jts_fr_basepts * 0.95 # combined_socre = combined_socre * 0.5 + score_jts_fr_basepts * 0.5 # combined_socre = combined_socre # combined_socre = score_jts_fr_basepts else: combined_socre = score_jts - torch.sqrt(1. - alpha_bar) * score_jts_fr_basepts combined_socre = score_jts # not cmb finger # # combined_socre = score_jts_fr_basepts # print(f"real_dec_basejtsrel: {real_dec_basejtsrel.size()}, score_jts_fr_basepts: {score_jts_fr_basepts.size()}, combined_socre: {combined_socre.size()}, score_jts: {score_jts.size()}") if self.args.use_var_sched: bsz = real_dec_basejtsrel.size(0) t_item = t[0].item() alpha = self.var_sched.alphas[t_item] alpha_bar = self.var_sched.alpha_bars[t_item] sigma = self.var_sched.get_sigmas(t_item, 0.) # sigma # # sigma = sigma / 2. 
# use_var_sched -> # c0 = 1.0 / torch.sqrt(alpha) c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar) beta = self.var_sched.betas[[t[0].item()] * bsz] z = torch.randn_like(jts_fr_basepts) if t_item > 0 else torch.zeros_like(jts_fr_basepts) dec_jts_fr_basepts = c0 * (jts_fr_basepts - c1 * combined_socre) + sigma * z # theta else: # x_{t-1} dec_jts_fr_basepts = self._predict_xstart_from_eps(jts_fr_basepts, t=t, eps=combined_socre) real_dec_basejtsrel = dec_jts_fr_basepts.unsqueeze(-2) - x['normed_base_pts'].unsqueeze(1).unsqueeze(1) # print(f"dec_jts_fr_basepts: {dec_jts_fr_basepts.size()}, normed_base_pts: ", x['normed_base_pts'].size(), "real_dec_basejtsrel:", real_dec_basejtsrel.size()) if self.args.real_basejtsrel_norm_stra == "std": real_dec_basejtsrel = (real_dec_basejtsrel - x['avg_rel_base_pts_to_rhand_joints']) / x['std_rel_base_pts_to_rhand_joints'] # print(f"real_dec_basejtsrel: {real_dec_basejtsrel.size()}, denormed_rel_base_pts_to_rhand_joints: {denormed_rel_base_pts_to_rhand_joints.size()}, jts_fr_basepts: {jts_fr_basepts.size()}") elif self.args.add_noise_onjts_single: # add noise on single joint jts_fr_basepts = pert_rel_base_pts_to_rhand_joints + x['normed_base_pts'].unsqueeze(1).unsqueeze(1) jts_fr_basepts = jts_fr_basepts.mean(dim=-2) dec_jts_fr_basepts = real_dec_basejtsrel.mean(dim=-2) # + x['normed_base_pts'].unsqueeze(1).unsqueeze(1) # rel dec basejtsrel # # dec_jts_fr_basepts = real_dec_basejtsrel[..., 0, :] t_item = t[0].item() alpha_bar = self.var_sched.alpha_bars[t_item] socre_jts_fr_basepts = dec_jts_fr_basepts if self.args.only_cmb_finger: combined_socre = score_jts # strategy 1 --> conditioning # # combined_socre[..., -5:, :] = combined_socre[..., -5:, :] - (torch.sqrt(1. - alpha_bar) * socre_jts_fr_basepts)[..., -5:, :] # strategy 2 --> linear interpolation # # combined_socre[..., -5:, :] = combined_socre[..., -5:, :] * 0.7 + socre_jts_fr_basepts[..., -5:, :] * 0.3 combined_socre[..., -5:, :] = combined_socre[..., -5:, :] * 0.5 + socre_jts_fr_basepts[..., -5:, :] * 0.5 else: combined_socre = score_jts - torch.sqrt(1. - alpha_bar) * socre_jts_fr_basepts combined_socre = score_jts # - torch.sqrt(1. - alpha_bar) * socre_jts_fr_basepts # combined_socre = socre_jts_fr_basepts if self.args.use_var_sched: bsz = dec_jts_fr_basepts.size(0) alpha = self.var_sched.alphas[t_item] alpha_bar = self.var_sched.alpha_bars[t_item] sigma = self.var_sched.get_sigmas(t_item, 0.) 
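# NOTE: the update assembled below is the standard DDPM ancestral-sampling step,
# written with c0 = 1 / sqrt(alpha_t) and c1 = (1 - alpha_t) / sqrt(1 - alpha_bar_t):
#   x_{t-1} = c0 * (x_t - c1 * eps_theta(x_t, t)) + sigma_t * z,  z ~ N(0, I), z = 0 at t = 0
# Here `combined_socre` plays the role of eps_theta and sigma comes from self.var_sched.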
# sigma # c0 = 1.0 / torch.sqrt(alpha) c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar) beta = self.var_sched.betas[[t[0].item()] * bsz] z = torch.randn_like(combined_socre) if t_item > 0 else torch.zeros_like(jts_fr_basepts) dec_jts_fr_basepts = c0 * (pert_rel_base_pts_outputs - c1 * combined_socre) + sigma * z # theta # bsz x seq x nnjts x nnbase x 3 # else: dec_jts_fr_basepts = self._predict_xstart_from_eps(pert_rel_base_pts_outputs, t=t, eps=combined_socre) real_dec_basejtsrel = dec_jts_fr_basepts.unsqueeze(-2) - x['normed_base_pts'].unsqueeze(1).unsqueeze(1) else: # raise ValueError(f"Add noise directly --- not implemented yet") # # input_data # self.args.real_basejtsrel_norm_stra == "std": # avg_rel_base_pts_to_rhand_joints, std_rel_base_pts_to_rhand_joints if self.args.real_basejtsrel_norm_stra == "std": ### std denormed_rel_base_pts_to_rhand_joints = pert_rel_base_pts_to_rhand_joints * x['std_rel_base_pts_to_rhand_joints'] + x['avg_rel_base_pts_to_rhand_joints'] else: denormed_rel_base_pts_to_rhand_joints = pert_rel_base_pts_to_rhand_joints jts_fr_basepts = denormed_rel_base_pts_to_rhand_joints + x['normed_base_pts'].unsqueeze(1).unsqueeze(1) jts_fr_basepts = jts_fr_basepts.mean(dim=-2) score_jts_fr_basepts = real_dec_basejtsrel.mean(dim=-2) # bsz x seq x nnjts x 3 t_item = t[0].item() alpha_bar = self.var_sched.alpha_bars[t_item] combined_socre = score_jts - torch.sqrt(1. - alpha_bar) * score_jts_fr_basepts combined_socre = score_jts # - torch.sqrt(1. - alpha_bar) * score_jts_fr_basepts combined_socre = score_jts_fr_basepts if self.args.use_var_sched: bsz = real_dec_basejtsrel.size(0) t_item = t[0].item() alpha = self.var_sched.alphas[t_item] alpha_bar = self.var_sched.alpha_bars[t_item] sigma = self.var_sched.get_sigmas(t_item, 0.) # sigma # # sigma = sigma / 2. c0 = 1.0 / torch.sqrt(alpha) c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar) beta = self.var_sched.betas[[t[0].item()] * bsz] z = torch.randn_like(jts_fr_basepts) if t_item > 0 else torch.zeros_like(jts_fr_basepts) dec_jts_fr_basepts = c0 * (jts_fr_basepts - c1 * combined_socre) + sigma * z # theta else: # x_{t-1} dec_jts_fr_basepts = self._predict_xstart_from_eps(jts_fr_basepts, t=t, eps=combined_socre) real_dec_basejtsrel = dec_jts_fr_basepts.unsqueeze(-2) - x['normed_base_pts'].unsqueeze(1).unsqueeze(1) if self.args.real_basejtsrel_norm_stra == "std": real_dec_basejtsrel = (real_dec_basejtsrel - x['avg_rel_base_pts_to_rhand_joints']) / x['std_rel_base_pts_to_rhand_joints'] # diff_realbasejtsrel real_basejtsrel_seq_rt_dict = { "real_dec_basejtsrel": real_dec_basejtsrel, } basejtsrel_seq_rt_dict = { "basejtsrel_output": dec_jts_fr_basepts } if self.args.train_enc: # pert_rel_base_pts_to_rhand_joints raise ValueError(f"Train enc --- Not implemented yet") pert_obj_base_pts_feats = x['pert_obj_base_pts_feats'] dec_obj_base_pts_feats = model.model.denoising_realbasejtsrel_objbasefeats(pert_obj_base_pts_feats, t) dec_obj_base_pts_feats = dec_obj_base_pts_feats.permute(1, 0, 2) pert_obj_base_pts_feats = pert_obj_base_pts_feats.permute(1, 0, 2) if self.args.pred_diff_noise: bsz = dec_obj_base_pts_feats.size(0) t_item = t[0].item() alpha = self.var_sched.alphas[t_item] alpha_bar = self.var_sched.alpha_bars[t_item] sigma = self.var_sched.get_sigmas(t_item, 0.)
# sigma # c0 = 1.0 / torch.sqrt(alpha) c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar) beta = self.var_sched.betas[[t[0].item()] * bsz] z = torch.randn_like(dec_obj_base_pts_feats) if t_item > 0 else torch.zeros_like(dec_obj_base_pts_feats) dec_obj_base_pts_feats = c0 * (pert_obj_base_pts_feats - c1 * dec_obj_base_pts_feats) + sigma * z # theta dec_obj_base_pts_feats = dec_obj_base_pts_feats.permute(1, 0, 2) real_dec_basejtsrel = model.model.decode_realbasejtsrel_from_objbasefeats(dec_obj_base_pts_feats, x) real_basejtsrel_seq_rt_dict.update( { 'real_dec_basejtsrel': real_dec_basejtsrel, 'dec_obj_base_pts_feats': dec_obj_base_pts_feats, } ) # else: # real_basejtsrel_seq_rt_dict = {} # basejtsrel_seq_rt_dict = {} if self.diff_basejtsrel and self.args.diff_realbasejtsrel_to_joints: pert_rel_base_pts_outputs = x['pert_joints_offset_sequence'] basejtsrel_output = out_dict['joints_offset_output'] score_jts = basejtsrel_output pert_joints_offset_output = x['pert_joints_offset_sequence'] # rel base pts outputs # dec_joints_offset_output = out_dict['joints_offset_output_from_rel'] # joints offset output # score_jts_fr_rel = dec_joints_offset_output # # pert joints offset sequence # # if self.args.pred_diff_noise: ## eps -> estimated-noises # predict noise? # # # if self.args.use_var_sched: bsz = dec_joints_offset_output.size(0) ## batch size fo the batch, pert-joints, predicted-joints ## t_item = t[0].item() alpha_bar = self.var_sched.alpha_bars[t_item] combined_socre = score_jts - torch.sqrt(1. - alpha_bar) * score_jts_fr_rel # combined_socre = score_jts # - torch.sqrt(1. - alpha_bar) * socre_jts_fr_basepts # combined_socre = score_jts_fr_rel alpha = self.var_sched.alphas[t_item] alpha_bar = self.var_sched.alpha_bars[t_item] sigma = self.var_sched.get_sigmas(t_item, 0.) # sigma # # sigma # c0 = 1.0 / torch.sqrt(alpha) ### c0? c1? --> c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar) beta = self.var_sched.betas[[t[0].item()] * bsz] z = torch.randn_like(dec_joints_offset_output) if t_item > 0 else torch.zeros_like(dec_joints_offset_output) # dec_joints dec_joints_offset_output = c0 * (pert_joints_offset_output - c1 * score_jts) + sigma * z # theta ### realjtsrel_to_joints and joints only ## realbasejtsrel_to_joints_rt_dict = { 'dec_joints_offset_output': dec_joints_offset_output, } basejtsrel_seq_rt_dict = { "basejtsrel_output": dec_joints_offset_output } # else: # realbasejtsrel_to_joints_rt_dict = {} # basejtsrel_seq_rt_dict = {} # rt_dict.update(jts_seq_rt_dict) rt_dict.update(basejtsrel_seq_rt_dict) # rt_dict.update(basejtse_seq_rt_dict) rt_dict.update(real_basejtsrel_seq_rt_dict) rt_dict.update(realbasejtsrel_to_joints_rt_dict) return rt_dict def p_mean_variance( ## get mean data ## self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None ): """ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of the initial x, x_0. :param model: the model, which takes a signal and a batch of timesteps as input. :param x: the [N x C x ...] tensor at time t. :param t: a 1-D Tensor of timesteps. :param clip_denoised: if True, clip the denoised signal into [-1, 1]. # # clip the denoised signal ## :param denoised_fn: if not None, a function which applies to the x_start prediction before it is used to sample. Applies before clip_denoised. # denoised fn :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :return: a dict with the following keys: - 'mean': the model mean output. - 'variance': the model variance output. 
- 'log_variance': the log of 'variance'. - 'pred_xstart': the prediction for x_0. """ if model_kwargs is None: model_kwargs = {} ## === version 1 -> predict x_start at the rel domain === ## ### == x -> formulated as model_inputs == ### ### == TODO: process x_start from the predicted value (relative positions, signed distances) with base_pts and baes_normals == ### # B = x['base_pts'].shape[0] B = t.shape[0] assert t.shape == (B,) # input_data = x ## dec_out and out ## ## output dict ## # out_dict = model(input_data, self._scale_timesteps(t).clone()) rt_dict = {} # # }[self.model_var_type] # ### === model variance and log_variance === ### ## posterior_log_variance_clipped, posterior_variance ## model_variance, model_log_variance = self.posterior_variance, self.posterior_log_variance_clipped # pmean variance if self.diff_jts: # x_t ## joints seq latents ## pert_joints_seq_latents = x['joints_seq_latents'] # x_t pred_clean_joint_seq_latents = out_dict["joints_seq_latents"] ## if self.args.pred_diff_noise: ## eps -> estimated noises ## t > for added joints latents ## pred_clean_joint_seq_latents = self._predict_xstart_from_eps(pert_joints_seq_latents.permute(1, 0, 2), t=t, eps=pred_clean_joint_seq_latents.permute(1, 0, 2)).permute(1, 0, 2) # seq x bs x d # # minn_pert_joints_seq_latents, _ = torch.min(pred_clean_joint_seq_latents.view(pred_clean_joint_seq_latents.size(0) * pred_clean_joint_seq_latents.size(1), -1).contiguous(), dim=0) # maxx_pert_joints_seq_latents, _ = torch.max(pred_clean_joint_seq_latents.view(pred_clean_joint_seq_latents.size(0) * pred_clean_joint_seq_latents.size(1), -1).contiguous(), dim=0) # print(f"pred minn_pert_joints_latents: {minn_pert_joints_seq_latents[:10]}, pred maxx_pert_joints_seq_latents: {maxx_pert_joints_seq_latents[:10]}") ## out_dict["joint_seq_output"] = model.model.dec_jts_only_fr_latents(pred_clean_joint_seq_latents)["joint_seq_output"] ## joints seq latents mean # # pred_clean_joint_seq_latents = pert_joints_seq_latents ## joints seq latents mean # joints_seq_latents_mean, _, _ = self.q_posterior_mean_variance( x_start=pred_clean_joint_seq_latents.permute(1, 0, 2), x_t=pert_joints_seq_latents.permute(1, 0, 2), t=t ) # joints_seq_latents_mean, basejtsrel_seq_latents_mean # joints_seq_latents_mean = joints_seq_latents_mean.permute(1, 0, 2) # joints_seq_latents_mean, basejtsrel_seq_latents_mean # joints_seq_latents_variance = _extract_into_tensor(model_variance, t, joints_seq_latents_mean.permute(1, 0, 2).shape).permute(1, 0, 2) joints_seq_latents_log_variance = _extract_into_tensor(model_log_variance, t, joints_seq_latents_mean.permute(1, 0, 2).shape).permute(1, 0, 2) # joint seq output # # joint seq output # joint_seq_output = out_dict["joint_seq_output"] jts_seq_rt_dict = { ### joints seq latents ### "joints_seq_latents_mean": joints_seq_latents_mean, "joints_seq_latents_variance": joints_seq_latents_variance, "joints_seq_latents_log_variance": joints_seq_latents_log_variance, ### decoded output values ### "joint_seq_output": joint_seq_output, } else: jts_seq_rt_dict = {} # /data1/sim/mdm/save/predoffset_stdscale_bsz_10_pred_diff_realbaesjtsrel_nonorm_std_for_norm_train_enc_with_diff_latents_prediffnoise_none_norm_rel_rel_to_jts_/model000007000.pt if self.args.diff_realbasejtsrel_to_joints: pert_joints_offset_output = x['pert_joints_offset_sequence'] # rel base pts outputs # dec_joints_offset_output = out_dict['joints_offset_output'] # joints offset output # if self.args.pred_diff_noise: ## eps -> estimated-noises # predict noise? 
                # if self.args.use_var_sched:
                bsz = dec_joints_offset_output.size(0) # batch size; pert-joints, predicted-joints #
                t_item = t[0].item()
                alpha = self.var_sched.alphas[t_item]
                alpha_bar = self.var_sched.alpha_bars[t_item]
                sigma = self.var_sched.get_sigmas(t_item, 0.) # sigma #
                c0 = 1.0 / torch.sqrt(alpha)
                c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
                beta = self.var_sched.betas[[t[0].item()] * bsz]
                z = torch.randn_like(dec_joints_offset_output) if t_item > 0 else torch.zeros_like(dec_joints_offset_output)
                dec_joints_offset_output = c0 * (pert_joints_offset_output - c1 * dec_joints_offset_output) + sigma * z # theta

            ## use the predicted latents and pert_latents for the seq latents prediction ##
            dec_joints_offset_output_mean, _, _ = self.q_posterior_mean_variance( # q posterior
                x_start=dec_joints_offset_output, x_t=pert_joints_offset_output, t=t
            )
            ## from model_variance to basejtsrel_seq_latents ##
            dec_joints_offset_output_variance = _extract_into_tensor(model_variance, t, dec_joints_offset_output.shape)
            dec_joints_offset_output_log_variance = _extract_into_tensor(model_log_variance, t, dec_joints_offset_output.shape)

            realbasejtsrel_to_joints_rt_dict = {
                'dec_joints_offset_output': dec_joints_offset_output,
            }
        else:
            realbasejtsrel_to_joints_rt_dict = {}

        if self.diff_realbasejtsrel: # diff_realbasejtsrel, real_basejtsrel_seq_rt_dict
            real_dec_basejtsrel = out_dict['real_dec_basejtsrel'] # bsz x nf x nnj x nnb x 3
            pert_rel_base_pts_to_rhand_joints = x['pert_rel_base_pts_to_rhand_joints']
            if self.args.pred_diff_noise and not self.args.train_enc:
                if self.args.add_noise_onjts:
                    jts_fr_basepts = pert_rel_base_pts_to_rhand_joints + x['normed_base_pts'].unsqueeze(1).unsqueeze(1)
                    dec_jts_fr_basepts = real_dec_basejtsrel # + x['normed_base_pts'].unsqueeze(1).unsqueeze(1) # rel dec basejtsrel #
                    if self.args.use_var_sched:
                        bsz = dec_jts_fr_basepts.size(0)
                        t_item = t[0].item()
                        alpha = self.var_sched.alphas[t_item]
                        alpha_bar = self.var_sched.alpha_bars[t_item]
                        sigma = self.var_sched.get_sigmas(t_item, 0.) # sigma #
                        c0 = 1.0 / torch.sqrt(alpha)
                        c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
                        beta = self.var_sched.betas[[t[0].item()] * bsz]
                        z = torch.randn_like(jts_fr_basepts) if t_item > 0 else torch.zeros_like(jts_fr_basepts)
                        dec_jts_fr_basepts = c0 * (jts_fr_basepts - c1 * dec_jts_fr_basepts) + sigma * z # theta # bsz x seq x nnjts x nnbase x 3 #
                        # real_dec_basejtsrel = c0 * (pert_rel_base_pts_to_rhand_joints - c1 * real_dec_basejtsrel) + sigma * z # theta
                        ## dec jts fr base pts ##
                        # dec_jts_fr_basepts = dec_jts_fr_basepts.mean(dim=-2).unsqueeze(-2).repeat(1, 1, 1, dec_jts_fr_basepts.size(-2), 1) # bsz x seq x nnjts x nnbase x 3 #
                    else: # x_{t-1}
                        dec_jts_fr_basepts = self._predict_xstart_from_eps(jts_fr_basepts, t=t, eps=dec_jts_fr_basepts)
                    real_dec_basejtsrel = dec_jts_fr_basepts - x['normed_base_pts'].unsqueeze(1).unsqueeze(1)
                elif self.args.add_noise_onjts_single:
                    jts_fr_basepts = pert_rel_base_pts_to_rhand_joints + x['normed_base_pts'].unsqueeze(1).unsqueeze(1)
                    jts_fr_basepts = jts_fr_basepts.mean(dim=-2)
                    # dec_jts_fr_basepts = real_dec_basejtsrel.mean(dim=-2) # + x['normed_base_pts'].unsqueeze(1).unsqueeze(1) # rel dec basejtsrel #
                    dec_jts_fr_basepts = real_dec_basejtsrel[..., 0, :]
                    if self.args.use_var_sched:
                        bsz = dec_jts_fr_basepts.size(0)
                        t_item = t[0].item()
                        alpha = self.var_sched.alphas[t_item]
                        alpha_bar = self.var_sched.alpha_bars[t_item]
                        sigma = self.var_sched.get_sigmas(t_item, 0.)
                        # sigma #
                        c0 = 1.0 / torch.sqrt(alpha)
                        c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
                        beta = self.var_sched.betas[[t[0].item()] * bsz]
                        z = torch.randn_like(jts_fr_basepts) if t_item > 0 else torch.zeros_like(jts_fr_basepts)
                        dec_jts_fr_basepts = c0 * (jts_fr_basepts - c1 * dec_jts_fr_basepts) + sigma * z # theta # bsz x seq x nnjts x nnbase x 3 #
                    else: # x_{t-1}
                        dec_jts_fr_basepts = self._predict_xstart_from_eps(jts_fr_basepts, t=t, eps=dec_jts_fr_basepts)
                    real_dec_basejtsrel = dec_jts_fr_basepts.unsqueeze(-2) - x['normed_base_pts'].unsqueeze(1).unsqueeze(1)
                else:
                    if self.args.use_var_sched:
                        bsz = real_dec_basejtsrel.size(0)
                        t_item = t[0].item()
                        alpha = self.var_sched.alphas[t_item]
                        alpha_bar = self.var_sched.alpha_bars[t_item]
                        sigma = self.var_sched.get_sigmas(t_item, 0.) # sigma #
                        # sigma = sigma / 2.
                        c0 = 1.0 / torch.sqrt(alpha)
                        c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
                        beta = self.var_sched.betas[[t[0].item()] * bsz]
                        z = torch.randn_like(pert_rel_base_pts_to_rhand_joints) if t_item > 0 else torch.zeros_like(pert_rel_base_pts_to_rhand_joints)
                        # z = torch.zeros_like(pert_rel_base_pts_to_rhand_joints)
                        real_dec_basejtsrel = c0 * (pert_rel_base_pts_to_rhand_joints - c1 * real_dec_basejtsrel) + sigma * z # theta
                        # dec_jts_fr_basepts = real_dec_basejtsrel + x['normed_base_pts'].unsqueeze(1).unsqueeze(1) # get dec_jts fr basepts #
                        # dec_jts_fr_basepts = dec_jts_fr_basepts.mean(dim=-2).unsqueeze(-2).repeat(1, 1, 1, dec_jts_fr_basepts.size(-2), 1) # repeated base pts #
                        # real_dec_basejtsrel = dec_jts_fr_basepts - x['normed_base_pts'].unsqueeze(1).unsqueeze(1) # real dec #
                    else: # x_{t-1}
                        real_dec_basejtsrel = self._predict_xstart_from_eps(pert_rel_base_pts_to_rhand_joints, t=t, eps=real_dec_basejtsrel)

            ## use the predicted latents and pert_latents for the seq latents prediction ##
            real_dec_basejtsrel_mean, _, _ = self.q_posterior_mean_variance( # q posterior
                x_start=real_dec_basejtsrel, x_t=pert_rel_base_pts_to_rhand_joints, t=t
            )
            ## from model_variance to basejtsrel_seq_latents ##
            real_dec_basejtsrel_variance = _extract_into_tensor(model_variance, t, real_dec_basejtsrel.shape)
            real_dec_basejtsrel_log_variance = _extract_into_tensor(model_log_variance, t, real_dec_basejtsrel.shape)

            # diff_realbasejtsrel
            real_basejtsrel_seq_rt_dict = {
                "real_dec_basejtsrel": real_dec_basejtsrel,
                "real_dec_basejtsrel_mean": real_dec_basejtsrel_mean,
                "real_dec_basejtsrel_variance": real_dec_basejtsrel_variance,
                "real_dec_basejtsrel_log_variance": real_dec_basejtsrel_log_variance,
            }

            if self.args.train_enc: # pert_rel_base_pts_to_rhand_joints
                pert_obj_base_pts_feats = x['pert_obj_base_pts_feats']
                dec_obj_base_pts_feats = model.model.denoising_realbasejtsrel_objbasefeats(pert_obj_base_pts_feats, t)
                dec_obj_base_pts_feats = dec_obj_base_pts_feats.permute(1, 0, 2)
                pert_obj_base_pts_feats = pert_obj_base_pts_feats.permute(1, 0, 2)
                if self.args.pred_diff_noise:
                    bsz = dec_obj_base_pts_feats.size(0)
                    t_item = t[0].item()
                    alpha = self.var_sched.alphas[t_item]
                    alpha_bar = self.var_sched.alpha_bars[t_item]
                    sigma = self.var_sched.get_sigmas(t_item, 0.)
                    # sigma #
                    c0 = 1.0 / torch.sqrt(alpha)
                    c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
                    beta = self.var_sched.betas[[t[0].item()] * bsz]
                    z = torch.randn_like(dec_obj_base_pts_feats) if t_item > 0 else torch.zeros_like(dec_obj_base_pts_feats)
                    dec_obj_base_pts_feats = c0 * (pert_obj_base_pts_feats - c1 * dec_obj_base_pts_feats) + sigma * z # theta
                dec_obj_base_pts_feats = dec_obj_base_pts_feats.permute(1, 0, 2)
                real_dec_basejtsrel = model.model.decode_realbasejtsrel_from_objbasefeats(dec_obj_base_pts_feats, x)
                real_basejtsrel_seq_rt_dict.update(
                    {
                        'real_dec_basejtsrel': real_dec_basejtsrel,
                        'dec_obj_base_pts_feats': dec_obj_base_pts_feats,
                    }
                )
        else:
            real_basejtsrel_seq_rt_dict = {}

        if self.diff_basejtsrel:
            # pert joints quants latents #
            pert_joints_quants_latents = x['pert_joints_quants_latents'] # get the perturbed joints quants latents #
            # denoised_latents: seq_len x bsz x latent_dim #
            denoised_latents = model.model.denoising_joint_quants_feats(pert_joints_quants_latents, self._scale_timesteps(t).clone())
            # pert_rel_base_pts_outputs = x['pert_joints_offset_sequence'] # rel base pts outputs #
            # basejtsrel_output = out_dict['joints_offset_output']
            if self.args.use_var_sched:
                # bsz = basejtsrel_output.size(0)
                bsz = B
                t_item = t[0].item()
                alpha = self.var_sched.alphas[t_item]
                alpha_bar = self.var_sched.alpha_bars[t_item]
                sigma = self.var_sched.get_sigmas(t_item, 0.) # sigma #
                c0 = 1.0 / torch.sqrt(alpha)
                c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
                # x_t = traj[t]
                # beta = self.var_sched.betas[[t[0].item()] * bsz]
                # if mask is not None:
                #     x_t = x_t * mask
                # e_theta = self.net(x_t, beta=beta, context=context) # denoising
                z = torch.randn_like(pert_joints_quants_latents) if t_item > 0 else torch.zeros_like(pert_joints_quants_latents)
                denoised_latents = c0 * (pert_joints_quants_latents - c1 * denoised_latents) + sigma * z # theta
            else:
                denoised_latents = self._predict_xstart_from_eps(pert_joints_quants_latents, t=t, eps=denoised_latents)

            if self.args.train_enc: # joints_quants_latents -> pred joint quants
                pred_joint_quants = model.model.pred_joint_quants_from_latent_feats(x['joints_quants_latents']) # joints_quants_latents for quants_latents #
            else:
                pred_joint_quants = model.model.pred_joint_quants_from_latent_feats(denoised_latents)

            # basejtsrel_output = out_dict["basejtsrel_output"]
            # print(f"basejtsrel_output: {basejtsrel_output.size()}")
            basejtsrel_seq_rt_dict = {
                ### basejtsrel seq latents ###
                # "avg_jts_outputs": avg_jts_outputs,
                # "basejtsrel_output_variance": basejtsrel_output_variance,
                # "basejtsrel_output_log_variance": basejtsrel_output_log_variance,
                # "avg_jts_outputs_variance": avg_jts_outputs_variance,
                # "avg_jts_outputs_log_variance": avg_jts_outputs_log_variance,
                # "basejtsrel_output": basejtsrel_output,
                # pred joint quants ->
                'pred_joint_quants': pred_joint_quants,
                'denoised_latents': denoised_latents,
            }
        else:
            basejtsrel_seq_rt_dict = {}

        if self.diff_basejtse:
            dec_e_along_normals = out_dict['dec_e_along_normals']
            dec_e_vt_normals = out_dict['dec_e_vt_normals']
            pert_e_along_normals = x['pert_e_disp_rel_to_base_along_normals']
            pert_e_vt_normals = x['pert_e_disp_rel_to_base_vt_normals']
            # pert_rel_base_pts_outputs = x['pert_joints_offset_sequence'] # rel base pts outputs #
            # basejtsrel_output = out_dict['joints_offset_output']
            if self.args.pred_diff_noise: ## eps -> estimated noises ##
                if self.args.use_var_sched:
                    bsz = dec_e_along_normals.size(0)
                    t_item = t[0].item()
                    alpha = self.var_sched.alphas[t_item]
                    alpha_bar = self.var_sched.alpha_bars[t_item]
                    sigma = self.var_sched.get_sigmas(t_item, 0.) # sigma #
                    c0 = 1.0 / torch.sqrt(alpha)
                    c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
                    beta = self.var_sched.betas[[t[0].item()] * bsz]
                    z = torch.randn_like(dec_e_along_normals) if t_item > 0 else torch.zeros_like(dec_e_along_normals)
                    dec_e_along_normals = c0 * (pert_e_along_normals - c1 * dec_e_along_normals) + sigma * z # theta
                    z_vt_normals = torch.randn_like(dec_e_vt_normals) if t_item > 0 else torch.zeros_like(dec_e_vt_normals)
                    dec_e_vt_normals = c0 * (pert_e_vt_normals - c1 * dec_e_vt_normals) + sigma * z_vt_normals # theta
                else:
                    dec_e_along_normals = self._predict_xstart_from_eps(pert_e_along_normals, t=t, eps=dec_e_along_normals)
                    dec_e_vt_normals = self._predict_xstart_from_eps(pert_e_vt_normals, t=t, eps=dec_e_vt_normals)

            # base_jts_e_feats = x['base_jts_e_feats'] ### x_t values here ###
            # pred_basejtse_seq_latents = out_dict['base_jts_e_feats']
            # ### q-sampled latent mean here ###
            # basejtse_seq_latents_mean, _, _ = self.q_posterior_mean_variance(
            #     x_start=pred_basejtse_seq_latents.permute(1, 0, 2), x_t=base_jts_e_feats.permute(1, 0, 2), t=t
            # )
            # basejtse_seq_latents_mean = basejtse_seq_latents_mean.permute(1, 0, 2)
            # basejtse_seq_latents_variance = _extract_into_tensor(model_variance, t, basejtse_seq_latents_mean.permute(1, 0, 2).shape).permute(1, 0, 2)
            # basejtse_seq_latents_log_variance = _extract_into_tensor(model_log_variance, t, basejtse_seq_latents_mean.permute(1, 0, 2).shape).permute(1, 0, 2)
            # base_jts_e_feats = out_dict["base_jts_e_feats"]
            # dec_e_along_normals = out_dict["dec_e_along_normals"]
            # dec_e_vt_normals = out_dict["dec_e_vt_normals"]

            basejtse_seq_rt_dict = {
                ### basejtse seq latents ###
                # "basejtse_seq_latents_mean": basejtse_seq_latents_mean,
                # "basejtse_seq_latents_variance": basejtse_seq_latents_variance,
                # "basejtse_seq_latents_log_variance": basejtse_seq_latents_log_variance,
                "dec_e_along_normals": dec_e_along_normals,
                "dec_e_vt_normals": dec_e_vt_normals,
            }
        else:
            basejtse_seq_rt_dict = {}

        rt_dict.update(jts_seq_rt_dict)
        rt_dict.update(basejtsrel_seq_rt_dict)
        rt_dict.update(basejtse_seq_rt_dict)
        rt_dict.update(real_basejtsrel_seq_rt_dict)
        rt_dict.update(realbasejtsrel_to_joints_rt_dict)

        return rt_dict

    def _predict_xstart_from_eps(self, x_t, t, eps):
        assert x_t.shape == eps.shape
        return ( # extract into tensor #
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
        )

    def _predict_xstart_from_xprev(self, x_t, t, xprev):
        assert x_t.shape == xprev.shape
        return ( # (xprev - coef2 * x_t) / coef1
            _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
            - _extract_into_tensor(
                self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
            )
            * x_t
        )

    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
        return (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - pred_xstart
        ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    def _scale_timesteps(self, t):
        if self.rescale_timesteps:
            return t.float() * (1000.0 / self.num_timesteps)
        return t
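    # The helpers above are the standard DDPM reparameterizations between x_0,
    # x_t, and eps under q(x_t | x_0) = N(sqrt(alpha_bar_t) x_0, (1 - alpha_bar_t) I):
    #     x_0 = sqrt(1 / alpha_bar_t) * x_t - sqrt(1 / alpha_bar_t - 1) * eps
    #     eps = (sqrt(1 / alpha_bar_t) * x_t - x_0) / sqrt(1 / alpha_bar_t - 1)
    # so sqrt_recip_alphas_cumprod stores sqrt(1 / alpha_bar) and
    # sqrt_recipm1_alphas_cumprod stores sqrt(1 / alpha_bar - 1).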
    def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute the mean for the previous step, given a function cond_fn that
        computes the gradient of a conditional log probability with respect to
        x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
        condition on y.

        This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
        """
        gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
        new_mean = (
            p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
        )
        return new_mean

    def condition_mean_with_grad(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute the mean for the previous step, given a function cond_fn that
        computes the gradient of a conditional log probability with respect to
        x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
        condition on y.

        This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
        """
        gradient = cond_fn(x, t, p_mean_var, **model_kwargs)
        new_mean = (
            p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
        )
        return new_mean

    def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute what the p_mean_variance output would have been, should the
        model's score function be conditioned by cond_fn.

        See condition_mean() for details on cond_fn.

        Unlike condition_mean(), this instead uses the conditioning strategy
        from Song et al (2020).
        """
        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)

        eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
        eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
            x, self._scale_timesteps(t), **model_kwargs
        )

        out = p_mean_var.copy()
        out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
        out["mean"], _, _ = self.q_posterior_mean_variance(
            x_start=out["pred_xstart"], x_t=x, t=t
        )
        return out

    def condition_score_with_grad(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute what the p_mean_variance output would have been, should the
        model's score function be conditioned by cond_fn.

        See condition_mean() for details on cond_fn.

        Unlike condition_mean(), this instead uses the conditioning strategy
        from Song et al (2020).
        """
        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)

        eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
        eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
            x, t, p_mean_var, **model_kwargs
        )

        out = p_mean_var.copy()
        out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
        out["mean"], _, _ = self.q_posterior_mean_variance(
            x_start=out["pred_xstart"], x_t=x, t=t
        )
        return out

    def judge_activated(self, target_setting):
        return 1 if target_setting else 0
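    # Both conditioning strategies shift the unconditional reverse process
    # towards p(x | y): condition_mean adds Sigma_t * grad_x log p(y | x_t) to
    # the posterior mean (Sohl-Dickstein et al., 2015; used for classifier
    # guidance by Dhariwal & Nichol, 2021), while condition_score perturbs the
    # predicted noise,
    #     eps <- eps - sqrt(1 - alpha_bar_t) * grad_x log p(y | x_t),
    # and re-derives x_0 and the posterior mean from the adjusted eps
    # (Song et al., 2020).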
""" multi_activated = ( self.judge_activated(self.diff_jts) + self.judge_activated(self.args.diff_realbasejtsrel_to_joints) + self.judge_activated(self.diff_realbasejtsrel) + self.judge_activated(self.diff_basejtsrel) + self.judge_activated(self.diff_basejtse) ) > 1.5 if multi_activated: # print(f"Multiple settings activated! Using combined sampling...") p_mena_variance_fn = self.p_mean_variance_cond else: # print(f"Single setting activated! Using single sampling...") p_mena_variance_fn = self.p_mean_variance out = p_mena_variance_fn( model, x, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) rt_dict = {} if self.diff_jts: # bsz x ws x nnj x nnb x 3 # joints_seq_latents_noise = th.randn_like(x['joints_seq_latents']) # print('const_noise', const_noise) # seq x bsz x latent_dim # if const_noise: print(f"joints latents hape, ", x['joints_seq_latents'].shape) joints_seq_latents_noise = joints_seq_latents_noise[[0]].repeat(x['joints_seq_latents'].shape[0], 1, 1) # joints_seq_latents_nonzero_mask = ( (t != 0).float().view(-1, *([1] * (len(x['joints_seq_latents'].shape) - 1))) ) # no noise when t == 0 # bsz x nseq x dim # #### ==== joints_seq_latents ===== #### # t -> seq for const nosie .... # cnanot dpeict the laten tspace very well... # joints_seq_latents_sample = out["joints_seq_latents_mean"].permute(1, 0, 2) + joints_seq_latents_nonzero_mask * th.exp(0.5 * out["joints_seq_latents_log_variance"].permute(1, 0, 2)) * joints_seq_latents_noise.permute(1, 0, 2) # nseq x bsz x dim # joints_seq_latents_sample = joints_seq_latents_sample.permute(1, 0, 2) # #### ==== joints_seq_latents ===== #### joint_seq_output = out["joint_seq_output"] jts_seq_rt_dict = { "joints_seq_latents_sample": joints_seq_latents_sample, "joint_seq_output": joint_seq_output, } else: jts_seq_rt_dict = {} if self.args.diff_realbasejtsrel_to_joints: ## args.pred to joints # dec_joints_offset_output = realbasejtsrel_to_joints_rt_dict = { 'dec_joints_offset_output': out['dec_joints_offset_output'] } else: realbasejtsrel_to_joints_rt_dict = {} if self.diff_realbasejtsrel: if self.args.train_enc or ( self.args.pred_diff_noise and self.args.use_var_sched): real_dec_basejtsrel = out['real_dec_basejtsrel'] else: real_dec_basejtsrel_noise = th.randn_like(out['real_dec_basejtsrel']) if const_noise: real_dec_basejtsrel_noise = real_dec_basejtsrel_noise[[0]].repeat(out['real_dec_basejtsrel'].shape[0], 1, 1, 1, 1) real_dec_basejtsrel_nonzero_mask = ( (t != 0).float().view(-1, *([1] * (len(out['real_dec_basejtsrel'].shape) - 1))) ) real_dec_basejtsrel = out["real_dec_basejtsrel_mean"] + real_dec_basejtsrel_nonzero_mask * th.exp(0.5 * out["real_dec_basejtsrel_log_variance"]) * real_dec_basejtsrel_noise real_basejtsrel_rt_dict = { 'real_dec_basejtsrel': real_dec_basejtsrel, } if self.args.train_enc: real_basejtsrel_rt_dict['dec_obj_base_pts_feats'] = out['dec_obj_base_pts_feats'] else: real_basejtsrel_rt_dict = {} if self.diff_basejtsrel: basejtsrel_rt_dict = out else: basejtsrel_rt_dict = {} if self.diff_basejtse: # ##### ===== Sample for basejtse_seq_latents_sample ===== ##### # ### rel_base_pts_outputs mask ### # basejtse_seq_latents_noise = th.randn_like(x['base_jts_e_feats']) # # print('const_noise', const_noise) # if const_noise: # basejtse_seq_latents_noise = basejtse_seq_latents_noise[[0]].repeat(x['base_jts_e_feats'].shape[0], 1, 1, 1, 1) # basejtse_seq_latents_nonzero_mask = ( # (t != 0).float().view(-1, *([1] * (len(x['base_jts_e_feats'].shape) - 1))) # ) # no noise when t == 0 # #### ==== 
            # #### ==== basejtsrel_seq_latents ==== ####
            # basejtse_seq_latents_sample = out["basejtse_seq_latents_mean"].permute(1, 0, 2) + basejtse_seq_latents_nonzero_mask * th.exp(0.5 * out["basejtse_seq_latents_log_variance"].permute(1, 0, 2)) * basejtse_seq_latents_noise.permute(1, 0, 2)
            # basejtse_seq_latents_sample = basejtse_seq_latents_sample.permute(1, 0, 2)
            # ##### ===== Sample for basejtse_seq_latents_sample ===== #####
            dec_e_along_normals = out["dec_e_along_normals"]
            dec_e_vt_normals = out["dec_e_vt_normals"]
            basejtse_rt_dict = {
                # "basejtse_seq_latents_sample": basejtse_seq_latents_sample,
                "dec_e_along_normals": dec_e_along_normals,
                "dec_e_vt_normals": dec_e_vt_normals,
            }
        else:
            basejtse_rt_dict = {}

        rt_dict.update(jts_seq_rt_dict)
        rt_dict.update(basejtsrel_rt_dict)
        rt_dict.update(basejtse_rt_dict)
        rt_dict.update(real_basejtsrel_rt_dict)
        rt_dict.update(realbasejtsrel_to_joints_rt_dict)

        return rt_dict

    def p_sample_with_grad(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
    ):
        """
        Sample x_{t-1} from the model at the given timestep.

        :param model: the model to sample from.
        :param x: the current tensor at x_{t-1}.
        :param t: the value of t, starting at 0 for the first diffusion step.
        :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample.
        :param cond_fn: if not None, this is a gradient function that acts
                        similarly to the model.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :return: a dict containing the following keys:
                 - 'sample': a random sample from the model.
                 - 'pred_xstart': a prediction of x_0.
        """
        with th.enable_grad():
            x = x.detach().requires_grad_()
            out = self.p_mean_variance(
                model,
                x,
                t,
                clip_denoised=clip_denoised,
                denoised_fn=denoised_fn,
                model_kwargs=model_kwargs,
            )
            noise = th.randn_like(x)
            nonzero_mask = (
                (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
            ) # no noise when t == 0
            if cond_fn is not None:
                out["mean"] = self.condition_mean_with_grad(
                    cond_fn, out, x, t, model_kwargs=model_kwargs
                )
        sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
        return {"sample": sample, "pred_xstart": out["pred_xstart"].detach()}
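    # p_sample_loop below simply drains p_sample_loop_progressive: it keeps the
    # last yielded dict as the final sample and, when dump_steps is given,
    # deep-copies the intermediate results at the requested step indices and
    # returns that list instead.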
    def p_sample_loop(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        skip_timesteps=0,
        init_image=None,
        randomize_class=False,
        cond_fn_with_grad=False,
        dump_steps=None,
        const_noise=False,
        st_timestep=None,
    ):
        """
        Generate samples from the model.

        :param model: the model module.
        :param shape: the shape of the samples, (N, C, H, W).
        :param noise: if specified, the noise from the encoder to sample.
                      Should be of the same shape as `shape`.
        :param clip_denoised: if True, clip x_start predictions to [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample.
        :param cond_fn: if not None, this is a gradient function that acts
                        similarly to the model.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param device: if specified, the device to create the samples on.
                       If not specified, use a model parameter's device.
        :param progress: if True, show a tqdm progress bar.
        :param const_noise: if True, will noise all samples with the same noise
                            throughout sampling.
        :return: a non-differentiable batch of samples.
        """
        final = None
        # if dump_steps is not None: ## dump steps is not None ##
        dump = []
        # function, yield, enumerate! ->
        for i, sample in enumerate(self.p_sample_loop_progressive(
            model, # p_sample #
            shape, # p_sample #
            noise=noise,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            cond_fn=cond_fn,
            model_kwargs=model_kwargs,
            device=device,
            progress=progress,
            skip_timesteps=skip_timesteps,
            init_image=init_image,
            randomize_class=randomize_class,
            cond_fn_with_grad=cond_fn_with_grad,
            const_noise=const_noise, # the same noise #
            st_timestep=st_timestep,
        )):
            if dump_steps is not None and i in dump_steps:
                dump.append(deepcopy(sample))
            final = sample
        if dump_steps is not None:
            return dump
        return final

    def p_sample_loop_progressive(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        skip_timesteps=0,
        init_image=None,
        randomize_class=False,
        cond_fn_with_grad=False,
        const_noise=False,
        st_timestep=None,
    ):
        """
        Generate samples from the model and yield intermediate samples from
        each timestep of diffusion.

        Arguments are the same as p_sample_loop().
        Returns a generator over dicts, where each dict is the return value of
        p_sample().
        """
        ####### ==== a conditional sampling from init_image here!!! ==== #######
        ## === give joints shape here === ##
        ### ==== set the shape for sampling ==== ###
        ### === init_image should not be None === ###
        base_pts = init_image['base_pts']
        base_normals = init_image['base_normals'] ## base normals ##
        # rel_base_pts_to_rhand_joints = init_image['rel_base_pts_to_rhand_joints']
        # dist_base_pts_to_rhand_joints = init_image['dist_base_pts_to_rhand_joints']
        rhand_joints = init_image['rhand_joints'] # rhand_joints #
        # rhand_joints = init_image['gt_rhand_joints'] ### init_image ###

        obj_rot = init_image['obj_rot'] # bsz x nf x 3 x 3 --> obj_rot
        obj_transl = init_image['obj_transl'] # obj_transl: bsz x nf x 3 --> transl of the obj base

        canon_rhand_joints = rhand_joints.clone()
        canon_base_pts = base_pts.unsqueeze(1).repeat(1, rhand_joints.size(1), 1, 1).contiguous() # bsz x nf x nn_base_pts x 3 #
        canon_base_normals = base_normals.unsqueeze(1).repeat(1, rhand_joints.size(1), 1, 1).contiguous() # bsz x nf x nn_base_pts x 3 #

        # base pts #
        base_pts = torch.matmul(base_pts.unsqueeze(1), obj_rot) + obj_transl.unsqueeze(-2) # bsz x nf x nn_base_pts x 3 #
        base_normals = torch.matmul(base_normals.unsqueeze(1), obj_rot) # bsz x nf x nn_base_pts x 3 #
        if len(self.args.predicted_info_fn) > 0:
            print(f"Joints are not transformed here!")
            rhand_joints = rhand_joints # obj_rot, obj_transl #
        else:
            rhand_joints = torch.matmul(rhand_joints, obj_rot) + obj_transl.unsqueeze(-2) # bsz x nf x nnj x 3 #

        gt_rhand_joints = init_image['gt_rhand_joints']
        canon_gt_rhand_joints = gt_rhand_joints.clone()
        gt_rhand_joints = torch.matmul(gt_rhand_joints, obj_rot) + obj_transl.unsqueeze(-2) # bsz x nf x nnj x 3 #
        # base_pts = torch.matmul(base_pts.unsqueeze(1), obj_rot) + obj_transl.unsqueeze(-2)
        # base_normals = torch.matmul(base_normals.unsqueeze(1), obj_rot)
        # rhand_joints = torch.matmul(rhand_joints, obj_rot) + obj_transl.unsqueeze(-2)
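        # The object geometry is given in a canonical frame; base_pts,
        # base_normals, and the joints are brought into the per-frame world
        # frame with the row-vector convention p_world = p_canon @ R_t + T_t
        # (normals are rotated only). The canon_* copies keep the untransformed
        # tensors around, since the temporal-v2 displacement quantities below
        # are computed from both frames.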
        # dist_joints_to_base_pts_disp, dist_disp_along_dir, dist_disp_vt_dir #
        # calculate_disp_quants_batched: velocity-like quantities, each of shape (bsz x (nf - 1) x nn_joints) #
        # joints quantities: (joints, base_pts_trans, canon_joints, canon_base_normals) #
        if self.args.use_temporal_rep_v2:
            disp_base_pts, dist_jts_to_base_pts, dist_disp_along_dir, dist_disp_vt_dir = calculate_disp_quants_batched_v2(rhand_joints, base_pts, canon_rhand_joints, canon_base_normals)
            # print(f"disp_base_pts: {disp_base_pts.size()}, dist_jts_to_base_pts: {dist_jts_to_base_pts.size()}, dist_disp_along_dir: {dist_disp_along_dir.size()}, dist_disp_vt_dir: {dist_disp_vt_dir.size()}")
            joints_quants = torch.cat(
                [disp_base_pts, dist_jts_to_base_pts.unsqueeze(-1), dist_disp_along_dir.unsqueeze(-1), dist_disp_vt_dir.unsqueeze(-1)], dim=-1
            )
            joints_quants_for_pred = torch.cat(
                [dist_jts_to_base_pts.unsqueeze(-1), dist_disp_along_dir.unsqueeze(-1), dist_disp_vt_dir.unsqueeze(-1)], dim=-1
            )
            # canon_gt_rhand_joints #
            disp_base_pts_gt, dist_jts_to_base_pts_gt, dist_disp_along_dir_gt, dist_disp_vt_dir_gt = calculate_disp_quants_batched_v2(gt_rhand_joints, base_pts, canon_gt_rhand_joints, canon_base_normals)
            diff_pert_gt_joints_to_base_pts_disp = torch.mean(
                (dist_jts_to_base_pts - dist_jts_to_base_pts_gt) ** 2
            )
            diff_pert_gt_dist_disp_along_dir = torch.mean(
                (dist_disp_along_dir - dist_disp_along_dir_gt) ** 2
            )
            diff_pert_gt_dist_disp_vt_dir = torch.mean(
                (dist_disp_vt_dir - dist_disp_vt_dir_gt) ** 2
            )
            print(f"[use_temporal_rep_v2] diff_pert_gt_joints_to_base_pts_disp: {diff_pert_gt_joints_to_base_pts_disp.item()}, diff_pert_gt_dist_disp_along_dir: {diff_pert_gt_dist_disp_along_dir.item()}, diff_pert_gt_dist_disp_vt_dir: {diff_pert_gt_dist_disp_vt_dir.item()}")

            gt_dist_joints_to_base_pts_disp = dist_jts_to_base_pts_gt
            gt_dist_disp_along_dir = dist_disp_along_dir_gt
            gt_dist_disp_vt_dir = dist_disp_vt_dir_gt
        else:
            # dist_joints_to_base_pts_disp, dist_disp_along_dir, dist_disp_vt_dir = calculate_disp_quants_batched(rhand_joints, base_pts)
            # # joints_quants: bsz x (nf - 1) x nnjoints x 3 --> joints_quants #
            # joints_quants = torch.cat( # why can't we predict this well?
            #     [dist_joints_to_base_pts_disp.unsqueeze(-1), dist_disp_along_dir.unsqueeze(-1), dist_disp_vt_dir.unsqueeze(-1)], dim=-1
            # )
            # joints_quants_pred = joints_quants.clone()

            # dist_joints_to_base_pts_disp, dist_disp_along_dir, dist_disp_vt_dir #
            # calculate_disp_quants_batched: for each of them --> (bsz x (nf - 1) x nn_joints) # joints quantities #
            dist_joints_to_base_pts_disp, dist_disp_along_dir, dist_disp_vt_dir = calculate_disp_quants_batched(rhand_joints, base_pts)
            gt_dist_joints_to_base_pts_disp, gt_dist_disp_along_dir, gt_dist_disp_vt_dir = calculate_disp_quants_batched(gt_rhand_joints, base_pts)
            # joints_quants: bsz x (nf - 1) x nnjoints x 3 --> joints_quants #
            joints_quants = torch.cat(
                [dist_joints_to_base_pts_disp.unsqueeze(-1), dist_disp_along_dir.unsqueeze(-1), dist_disp_vt_dir.unsqueeze(-1)], dim=-1
            )
            joints_quants_for_pred = joints_quants.clone()

            # gt_dist_joints_to_base_pts_disp, gt_dist_disp_along_dir, gt_dist_disp_vt_dir #
            diff_pert_gt_joints_to_base_pts_disp = torch.mean(
                (dist_joints_to_base_pts_disp - gt_dist_joints_to_base_pts_disp) ** 2
            )
            diff_pert_gt_dist_disp_along_dir = torch.mean(
                (dist_disp_along_dir - gt_dist_disp_along_dir) ** 2
            )
            diff_pert_gt_dist_disp_vt_dir = torch.mean(
                (dist_disp_vt_dir - gt_dist_disp_vt_dir) ** 2
            )
            ### diff_pert_gt_joints_to_base_pts_disp; diff_pert_gt_dist_disp_along_dir; diff_pert_gt_dist_disp_vt_dir ###
            print(f"diff_pert_gt_joints_to_base_pts_disp: {diff_pert_gt_joints_to_base_pts_disp.item()}, diff_pert_gt_dist_disp_along_dir: {diff_pert_gt_dist_disp_along_dir.item()}, diff_pert_gt_dist_disp_vt_dir: {diff_pert_gt_dist_disp_vt_dir.item()}")

            # joints_quants = torch.cat(
            #     [dist_joints_to_base_pts_disp.unsqueeze(-1), dist_joints_to_base_pts_disp.unsqueeze(-1), dist_joints_to_base_pts_disp.unsqueeze(-1)], dim=-1
            # )
            # joints_quants = torch.cat(
            #     [dist_disp_vt_dir.unsqueeze(-1), dist_disp_vt_dir.unsqueeze(-1), dist_disp_vt_dir.unsqueeze(-1)], dim=-1
            # )

        # x = joints_quants
        input_data = {
            'joints_quants': joints_quants
        } ### without e normalization ###

        # indices
        indices = list(range(self.num_timesteps - skip_timesteps))[::-1]
        if st_timestep is not None: ## indices ##
            indices = indices[-st_timestep:]
            print(f"st_timestep: {st_timestep}, indices: {indices}")

        my_t = th.tensor([indices[0]] * shape[0], device=device).cuda()

        ### the strategy of removing noise from corresponding quantities ###
        out_dict = model(input_data, self._scale_timesteps(my_t).clone()) ### get model output dictionary ###
        joints_quants_latents = out_dict['joints_quants_latents']
        pred_joint_quants = out_dict['joints_quants_output']

        # joints_quants_latents #
        noise_joints_quants_latents = th.randn_like(joints_quants_latents)
        pert_joints_quants_latents = self.q_sample( # joints_quants_latents
            joints_quants_latents.permute(1, 0, 2), my_t, noise_joints_quants_latents.permute(1, 0, 2)
        ).permute(1, 0, 2) # seq_len x bsz x latent_dim #
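        # The clean latents from the encoder are perturbed with the forward
        # process q(x_t | x_0): q_sample draws
        #     x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise,
        # so sampling then amounts to denoising these latents from step
        # indices[0] downwards; with st_timestep set, only the last st_timestep
        # steps of the schedule are traversed.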
        # input_data = {
        #     'base_pts': base_pts,
        #     'base_normals': base_normals,
        #     # 'rel_base_pts_to_rhand_joints': rel_base_pts_to_pert_rhand_joints,
        #     # 'dist_base_pts_to_rhand_joints': dist_base_pts_to_pert_rhand_joints,
        #     # 'pert_rhand_joints': pert_normed_rhand_joints,
        #     # 'pert_rhand_joints': pert_scaled_rhand_joints,
        #     'rhand_joints': rhand_joints,
        #     # 'rel_base_pts_to_rhand_joints': rel_base_pts_to_rhand_joints,
        #     # 'rel_base_pts_to_rhand_joints': rel_base_pts_to_rhand_joints.clone(),
        #     # 'avg_joints_sequence': avg_joints_sequence,
        #     # 'pert_avg_joints_sequence': pert_avg_joints_sequence, ## pert avg joints sequence ##
        #     'pert_rel_base_pts_to_rhand_joints': pert_rel_base_pts_to_rhand_joints, ## pert relative base pts to rhand joints ##
        #     'pert_joints_offset_sequence': pert_joints_offset_sequence,
        #     'normed_base_pts': normed_base_pts,
        #     'pert_rel_base_pts_to_joints_for_jts_pred': pert_rel_base_pts_to_joints_for_jts_pred, ## bsz x nf x nnj x nnb x 3 --> from base points to joints ##
        # }
        input_data['pert_joints_quants_latents'] = pert_joints_quants_latents # add perturbed information
        input_data['joints_quants_latents'] = joints_quants_latents # clean latent vectors

        # if self.args.train_enc:
        #     diff_joints_quants_pred_joints_quants_dist = torch.sum(
        #         (joints_quants[..., 0:1] - pred_joint_quants[..., 0:1]) ** 2, dim=-1
        #     ).mean()
        #     print(f"[train_enc] Current step diff_joints_quants_pred_joints_quants_dist: {diff_joints_quants_pred_joints_quants_dist.item()}")
        #     diff_joints_quants_pred_joints_quants_along_normals = torch.sum(
        #         (joints_quants[..., 1:2] - pred_joint_quants[..., 1:2]) ** 2, dim=-1
        #     ).mean()
        #     print(f"[train_enc] Current step diff_joints_quants_pred_joints_quants_along_normals: {diff_joints_quants_pred_joints_quants_along_normals.item()}")
        #     diff_joints_quants_pred_joints_quants_vt_normals = torch.sum(
        #         (joints_quants[..., 2:] - pred_joint_quants[..., 2:]) ** 2, dim=-1
        #     ).mean()
        #     print(f"[train_enc] Current step diff_joints_quants_pred_joints_quants_vt_normals: {diff_joints_quants_pred_joints_quants_vt_normals.item()}")
        #     input_data.update(
        #         {'pred_joint_quants': pred_joint_quants}
        #     )
        #     for t in range(len(indices)):
        #         yield input_data
        #     return

        # primal space denoising ->
        # input_data.update(
        #     {
        #         'e_disp_rel_to_base_along_normals': e_disp_rel_to_base_along_normals,
        #         'e_disp_rel_to_baes_vt_normals': e_disp_rel_to_baes_vt_normals,
        #         'pert_e_disp_rel_to_base_along_normals': pert_e_disp_rel_to_base_along_normals,
        #         'pert_e_disp_rel_to_base_vt_normals': pert_e_disp_rel_to_base_vt_normals,
        #     }
        # )
        # input_data.update(init_image_avg_std_stats)
        # input_data['rhand_joints'] = rhand_joints # normed #

        # if self.args.real_basejtsrel_norm_stra == "std": # avg_rel_base_pts_to_rhand_joints, std_rel_base_pts_to_rhand_joints
        #     input_data.update(
        #         {
        #             'avg_rel_base_pts_to_rhand_joints': avg_rel_base_pts_to_rhand_joints,
        #             'std_rel_base_pts_to_rhand_joints': std_rel_base_pts_to_rhand_joints,
        #         }
        #     )

        if self.args.train_enc:
            # out_dict = model(input_data, self._scale_timesteps(t).clone())
            # obj_base_pts_feats = out_dict['obj_base_pts_feats'] # obj base pts feats #
            # # noise_obj_base_pts_feats = torch.zeros_like(obj_base_pts_feats)
            # noise_obj_base_pts_feats = torch.randn_like(obj_base_pts_feats)
            # pert_obj_base_pts_feats = self.q_sample(
            #     obj_base_pts_feats.permute(1, 0, 2), my_t, noise_obj_base_pts_feats.permute(1, 0, 2)
            # ).permute(1, 0, 2)
            # if self.args.rnd_noise:
            #     pert_obj_base_pts_feats = noise_obj_base_pts_feats
            # input_data['pert_obj_base_pts_feats'] = pert_obj_base_pts_feats
            input_data['pert_joints_quants_latents'] = joints_quants_latents # use the clean latents when training the encoder

        model_kwargs = {
            k: val for k, val in init_image.items() if k not in input_data
        }
        if progress:
            # Lazy import so that we don't depend on tqdm.
            from tqdm.auto import tqdm
            indices = tqdm(indices)

        for i_idx, i in enumerate(indices):
            t = th.tensor([i] * shape[0], device=device)
            if randomize_class and 'y' in model_kwargs:
                model_kwargs['y'] = th.randint(low=0,
                                               high=model.num_classes,
                                               size=model_kwargs['y'].shape, # size of y
                                               device=model_kwargs['y'].device) # device of y
            with th.no_grad(): # inter_optim # progress #
                # p_sample_with_grad ## p_sample with grad ##
                # or for each joint -> the features ->
                sample_fn = self.p_sample_with_grad if cond_fn_with_grad else self.p_sample
                out = sample_fn(
                    model,
                    input_data, ## sample from input data ##
                    t,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn,
                    cond_fn=cond_fn,
                    model_kwargs=model_kwargs,
                    const_noise=const_noise,
                    ## new representation strategies; resolve penetrations ##
                    ### penetrations for new representations ###
                )

                if self.diff_basejtsrel:
                    # 'pred_joint_quants': pred_joint_quants,
                    # 'denoised_latents': denoised_latents,
                    pred_joint_quants = out['pred_joint_quants']
                    denoised_latents = out['denoised_latents'] # denoised latents #
                    # basejtsrel_seq_latents_sample = out["basejtsrel_seq_latents_sample"] ## basejtsrel output ##
                    # 'real_dec_basejtsrel': real_dec_basejtsrel,
                    # 'dec_obj_base_pts_feats': dec_obj_base_pts_feats,
                    # if self.args.pred_joints_offset: # pred
                    # basejtsrel_seq_latents_sample: bsz x nf x nnj x 3 #
                    # sampled_rhand_joints = basejtsrel_seq_latents_sample * std_joints_sequence.unsqueeze(1) + avg_jts_outputs_sample.unsqueeze(1)
                    # sampled_rhand_joints = basejtsrel_seq_latents_sample * std_joints_sequence.unsqueeze(1) + avg_joints_sequence.unsqueeze(1)

                    diff_joints_quants_pred_joints_quants_dist = torch.sum(
                        (joints_quants_for_pred[..., 0:1] - pred_joint_quants[..., 0:1]) ** 2, dim=-1
                    ).mean()
                    print(f"Current step diff_joints_quants_pred_joints_quants_dist: {diff_joints_quants_pred_joints_quants_dist.item()}")
                    ### diff_joints_quants_pred_joints_quants ###
                    diff_joints_quants_pred_joints_quants_along_normals = torch.sum(
                        (joints_quants_for_pred[..., 1:2] - pred_joint_quants[..., 1:2]) ** 2, dim=-1
                    ).mean()
                    print(f"Current step diff_joints_quants_pred_joints_quants_along_normals: {diff_joints_quants_pred_joints_quants_along_normals.item()}")
                    diff_joints_quants_pred_joints_quants_vt_normals = torch.sum(
                        (joints_quants_for_pred[..., 2:] - pred_joint_quants[..., 2:]) ** 2, dim=-1
                    ).mean()
                    print(f"Current step diff_joints_quants_pred_joints_quants_vt_normals: {diff_joints_quants_pred_joints_quants_vt_normals.item()}")

                    # gt_dist_joints_to_base_pts_disp, gt_dist_disp_along_dir, gt_dist_disp_vt_dir #
                    diff_joints_quants_pred_joints_quants_dist_gt = torch.sum(
                        (gt_dist_joints_to_base_pts_disp.unsqueeze(-1) - pred_joint_quants[..., 0:1]) ** 2, dim=-1
                    ).mean()
                    print(f"Current step diff_joints_quants_pred_joints_quants_dist_gt: {diff_joints_quants_pred_joints_quants_dist_gt.item()}")
                    ### diff_joints_quants_pred_joints_quants ###
                    diff_joints_quants_pred_joints_quants_along_normals_gt = torch.sum(
                        (gt_dist_disp_along_dir.unsqueeze(-1) - pred_joint_quants[..., 1:2]) ** 2, dim=-1
                    ).mean()
                    print(f"Current step diff_joints_quants_pred_joints_quants_along_normals_gt: {diff_joints_quants_pred_joints_quants_along_normals_gt.item()}")
                    diff_joints_quants_pred_joints_quants_vt_normals_gt = torch.sum(
                        (gt_dist_disp_vt_dir.unsqueeze(-1) - pred_joint_quants[..., 2:]) ** 2, dim=-1
                    ).mean()
                    print(f"Current step diff_joints_quants_pred_joints_quants_vt_normals_gt: {diff_joints_quants_pred_joints_quants_vt_normals_gt.item()}")
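                    # The per-step prints above are diagnostics only: they track
                    # the MSE between the currently decoded joint quantities and
                    # (a) the perturbed input quantities and (b) the ground-truth
                    # quantities, per channel (distance, displacement along the
                    # normal, displacement perpendicular to the normal). They do
                    # not feed back into the sampling update.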
                    basejtsrel_seq_dec_in_dict = {
                        # finetune_with_cond #
                        # 'pert_avg_joints_sequence': out["avg_jts_outputs_sample"] if 'avg_jts_outputs_sample' in out else pert_avg_joints_sequence, ## for avg-jts sequence ##
                        # 'pert_rel_base_pts_to_rhand_joints': pert_rel_base_pts_to_rhand_joints, ## pert relative base pts to rhand joints ##
                        # 'sampled_rhand_joints': sampled_rhand_joints, ## sampled rhand joints ##
                        # 'pert_joints_offset_sequence': out["basejtsrel_seq_latents_sample"], ## another choice ##
                        'pert_joints_quants_latents': denoised_latents,
                        'pred_joint_quants': pred_joint_quants,
                        'sampled_rhand_joints': rhand_joints,
                    }
                    input_data.update(basejtsrel_seq_dec_in_dict)
                else:
                    # basejtsrel_seq_input_dict = {}
                    basejtsrel_seq_dec_in_dict = {}

                if self.args.diff_realbasejtsrel_to_joints:
                    # predicted x_{t-1} (normalized) # rel to joints #
                    dec_joints_offset_output = out['dec_joints_offset_output']
                    ## minus normed base pts here ## from normed pts and offset outputs ##
                    pert_rel_base_pts_to_joints_for_jts_pred = dec_joints_offset_output.unsqueeze(-2) - normed_base_pts.unsqueeze(1).unsqueeze(1)
                    # predicted x_{t-1} before normalization #
                    sampled_rhand_joints = dec_joints_offset_output * std_joints_sequence.unsqueeze(1) + avg_joints_sequence.unsqueeze(1)
                    realbasejtsrel_to_joints_dec_in_dict = {
                        'pert_rel_base_pts_to_joints_for_jts_pred': pert_rel_base_pts_to_joints_for_jts_pred, # bsz x nf x nnj x nnb x 3 #
                        'sampled_rhand_joints': sampled_rhand_joints,
                        'pert_joints_offset_sequence': dec_joints_offset_output,
                    }
                    input_data.update(realbasejtsrel_to_joints_dec_in_dict)

                if self.diff_realbasejtsrel:
                    real_dec_basejtsrel = out["real_dec_basejtsrel"] # bsz x nf x nnj x nnb x 3 #
                    # avg_rel_base_pts_to_rhand_joints, std_rel_base_pts_to_rhand_joints.unsqueeze(1).unsqueeze(1) #
                    #### add_noise_onjts; add_noise_onjts_single ####
                    if self.args.real_basejtsrel_norm_stra == "std" and (not self.args.add_noise_onjts) and (not self.args.add_noise_onjts_single):
                        real_dec_basejtsrel_pred_sample = real_dec_basejtsrel * std_rel_base_pts_to_rhand_joints.unsqueeze(1).unsqueeze(1) + avg_rel_base_pts_to_rhand_joints.unsqueeze(1).unsqueeze(1)
                        real_dec_basejtsrel_pred_sample = real_dec_basejtsrel_pred_sample + normed_base_pts.unsqueeze(1).unsqueeze(1)
                    else:
                        # real_dec_basejtsrel_pred_sample = real_dec_basejtsrel.clone()
                        real_dec_basejtsrel_pred_sample = real_dec_basejtsrel + normed_base_pts.unsqueeze(1).unsqueeze(1) # real dec basejtsrel pred sample #
                        # real_dec_basejtsrel_pred_sample = real_dec_basejtsrel_pred_sample + normed_base_pts.unsqueeze(1).unsqueeze(1) # bsz x nf x nnj x nnb x 3 #

                    # std_joints_sequence.unsqueeze(1), avg_joints_sequence.unsqueeze(1) # real pred samples #
                    if self.args.use_t == 1000:
                        sampled_rhand_joints = real_dec_basejtsrel_pred_sample[..., self.args.sel_basepts_idx, :]
                        # sampled_rhand_joints = real_dec_basejtsrel_pred_sample.mean(dim=-2)
                    else:
                        sampled_rhand_joints = real_dec_basejtsrel_pred_sample.mean(dim=-2) # bsz x nf x nnj x 3 #
                        # sampled_rhand_joints = real_dec_basejtsrel_pred_sample[..., self.args.sel_basepts_idx, :]
                        # sampled_rhand_joints = real_dec_basejtsrel_pred_sample[..., 0, :]
                    # std joints sequence; std_joints #
                    # sampled_rhand_joints = sampled_rhand_joints * std_joints_sequence.unsqueeze(1) + avg_joints_sequence.unsqueeze(1)
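                    # real_dec_basejtsrel holds one predicted joint position per
                    # (joint, base point) pair; a single joint trajectory is
                    # recovered either by picking one reference base point
                    # (sel_basepts_idx) or by averaging the per-base-point
                    # predictions over the base-point axis.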
                    real_basejtsrel_dec_in_dict = {
                        # 'pert_rel_base_pts_to_rhand_joints': real_dec_basejtsrel, ## real dec basejtsrel ##
                        'sampled_rhand_joints': sampled_rhand_joints,
                    }
                    if not self.diff_basejtsrel:
                        real_basejtsrel_dec_in_dict['sampled_rhand_joints'] = sampled_rhand_joints
                    if self.args.train_enc:
                        real_basejtsrel_dec_in_dict['pert_obj_base_pts_feats'] = out['dec_obj_base_pts_feats']
                    input_data.update(real_basejtsrel_dec_in_dict)
                else:
                    real_basejtsrel_dec_in_dict = {}

                if self.diff_basejtse:
                    ## seq latents ##
                    # basejtse_seq_latents_sample = out["basejtse_seq_latents_sample"]
                    pert_dec_e_along_normals = out["dec_e_along_normals"]
                    pert_dec_e_vt_normals = out["dec_e_vt_normals"]
                    dec_e_along_normals = pert_dec_e_along_normals * init_image['per_frame_std_disp_along_normals'] + init_image['per_frame_avg_disp_along_normals']
                    dec_e_vt_normals = pert_dec_e_vt_normals * init_image['per_frame_std_disp_vt_normals'] + init_image['per_frame_avg_disp_vt_normals']
                    dec_e_along_normals = torch.clamp(dec_e_along_normals, min=0.)
                    dec_e_vt_normals = torch.clamp(dec_e_vt_normals, min=0.)
                    # scale base ## model constraints and model impacts from object a to object c ##
                    basejtse_seq_input_dict = {
                        'pert_e_disp_rel_to_base_along_normals': pert_dec_e_along_normals,
                        'pert_e_disp_rel_to_base_vt_normals': pert_dec_e_vt_normals,
                        'e_disp_rel_to_base_along_normals': dec_e_along_normals,
                        'e_disp_rel_to_baes_vt_normals': dec_e_vt_normals,
                        'sampled_rhand_joints': rhand_joints,
                    }
                    input_data.update(basejtse_seq_input_dict)
                else:
                    basejtse_seq_input_dict = {}
                    # basejtse_seq_dec_in_dict = {}

            yield input_data

    def ddim_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        eta=0.0,
    ):
        """
        Sample x_{t-1} from the model using DDIM.

        Same usage as p_sample().
        """
        out_orig = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        if cond_fn is not None:
            out = self.condition_score(cond_fn, out_orig, x, t, model_kwargs=model_kwargs)
        else:
            out = out_orig

        # Usually our model outputs epsilon, but we re-derive it
        # in case we used x_start or x_prev prediction.
        eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])

        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
        alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
        sigma = (
            eta
            * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
            * th.sqrt(1 - alpha_bar / alpha_bar_prev)
        )
        # Equation 12. ## hand position and relative ##
        noise = th.randn_like(x)
        mean_pred = (
            out["pred_xstart"] * th.sqrt(alpha_bar_prev)
            + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
        )
        nonzero_mask = (
            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        ) # no noise when t == 0
        sample = mean_pred + nonzero_mask * sigma * noise
        return {"sample": sample, "pred_xstart": out_orig["pred_xstart"]}
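    # DDIM step (Song et al., 2021, Eq. 12): with x0_hat predicted from x_t,
    #     x_{t-1} = sqrt(alpha_bar_{t-1}) * x0_hat
    #             + sqrt(1 - alpha_bar_{t-1} - sigma_t^2) * eps_hat
    #             + sigma_t * z,
    # where sigma_t = eta * sqrt((1 - alpha_bar_{t-1}) / (1 - alpha_bar_t))
    #                     * sqrt(1 - alpha_bar_t / alpha_bar_{t-1});
    # eta = 0 gives the deterministic DDIM sampler, eta = 1 recovers
    # DDPM-like stochasticity.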
    def ddim_sample_with_grad(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        eta=0.0,
    ):
        """
        Sample x_{t-1} from the model using DDIM.

        Same usage as p_sample().
        """
        with th.enable_grad():
            x = x.detach().requires_grad_()
            out_orig = self.p_mean_variance(
                model,
                x,
                t,
                clip_denoised=clip_denoised,
                denoised_fn=denoised_fn,
                model_kwargs=model_kwargs,
            )
            if cond_fn is not None:
                out = self.condition_score_with_grad(cond_fn, out_orig, x, t,
                                                     model_kwargs=model_kwargs)
            else:
                out = out_orig

        out["pred_xstart"] = out["pred_xstart"].detach()

        # Usually our model outputs epsilon, but we re-derive it
        # in case we used x_start or x_prev prediction.
        eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])

        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
        alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
        sigma = (
            eta
            * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
            * th.sqrt(1 - alpha_bar / alpha_bar_prev)
        )
        # Equation 12.
        noise = th.randn_like(x)
        mean_pred = (
            out["pred_xstart"] * th.sqrt(alpha_bar_prev)
            + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
        )
        nonzero_mask = (
            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        ) # no noise when t == 0
        sample = mean_pred + nonzero_mask * sigma * noise
        return {"sample": sample, "pred_xstart": out_orig["pred_xstart"].detach()}

    def ddim_reverse_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        model_kwargs=None,
        eta=0.0,
    ):
        """
        Sample x_{t+1} from the model using DDIM reverse ODE.
        """
        assert eta == 0.0, "Reverse ODE only for deterministic path"
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        # Usually our model outputs epsilon, but we re-derive it
        # in case we used x_start or x_prev prediction.
        eps = (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
            - out["pred_xstart"]
        ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
        alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)

        # Equation 12. reversed
        mean_pred = (
            out["pred_xstart"] * th.sqrt(alpha_bar_next)
            + th.sqrt(1 - alpha_bar_next) * eps
        )

        return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}

    def ddim_sample_loop(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        eta=0.0,
        skip_timesteps=0,
        init_image=None,
        randomize_class=False,
        cond_fn_with_grad=False,
        dump_steps=None,
        const_noise=False,
    ):
        """
        Generate samples from the model using DDIM.

        Same usage as p_sample_loop().
        """
        if dump_steps is not None:
            raise NotImplementedError()
        if const_noise:
            raise NotImplementedError()

        final = None
        for sample in self.ddim_sample_loop_progressive(
            model,
            shape,
            noise=noise,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            cond_fn=cond_fn,
            model_kwargs=model_kwargs,
            device=device,
            progress=progress,
            eta=eta,
            skip_timesteps=skip_timesteps,
            init_image=init_image,
            randomize_class=randomize_class,
            cond_fn_with_grad=cond_fn_with_grad,
        ):
            final = sample
        return final["sample"]

    def ddim_sample_loop_progressive(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        eta=0.0,
        skip_timesteps=0,
        init_image=None,
        randomize_class=False,
        cond_fn_with_grad=False,
    ):
        """
        Use DDIM to sample from the model and yield intermediate samples from
        each timestep of DDIM.

        Same usage as p_sample_loop_progressive().
        """
        if device is None:
            device = next(model.parameters()).device
        assert isinstance(shape, (tuple, list))
        if noise is not None:
            img = noise
        else:
            img = th.randn(*shape, device=device)

        if skip_timesteps and init_image is None:
            init_image = th.zeros_like(img)

        indices = list(range(self.num_timesteps - skip_timesteps))[::-1]

        if init_image is not None:
            my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0]
            img = self.q_sample(init_image, my_t, img)
        if progress:
            # Lazy import so that we don't depend on tqdm.
            from tqdm.auto import tqdm
            indices = tqdm(indices)

        for i in indices:
            t = th.tensor([i] * shape[0], device=device)
            if randomize_class and 'y' in model_kwargs:
                model_kwargs['y'] = th.randint(low=0,
                                               high=model.num_classes,
                                               size=model_kwargs['y'].shape,
                                               device=model_kwargs['y'].device)
            with th.no_grad():
                sample_fn = self.ddim_sample_with_grad if cond_fn_with_grad else self.ddim_sample
                out = sample_fn(
                    model,
                    img,
                    t,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn,
                    cond_fn=cond_fn,
                    model_kwargs=model_kwargs,
                    eta=eta,
                )
                yield out
                img = out["sample"]

    def plms_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        cond_fn_with_grad=False,
        order=2,
        old_out=None,
    ):
        """
        Sample x_{t-1} from the model using Pseudo Linear Multistep.

        Same usage as p_sample().
        """
        if not int(order) or not 1 <= order <= 4:
            raise ValueError('order is invalid (should be int from 1-4).')

        def get_model_output(x, t):
            with th.set_grad_enabled(cond_fn_with_grad and cond_fn is not None):
                x = x.detach().requires_grad_() if cond_fn_with_grad else x
                out_orig = self.p_mean_variance(
                    model,
                    x,
                    t,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn,
                    model_kwargs=model_kwargs,
                )
                if cond_fn is not None:
                    if cond_fn_with_grad:
                        out = self.condition_score_with_grad(cond_fn, out_orig, x, t, model_kwargs=model_kwargs)
                        x = x.detach()
                    else:
                        out = self.condition_score(cond_fn, out_orig, x, t, model_kwargs=model_kwargs)
                else:
                    out = out_orig

            # Usually our model outputs epsilon, but we re-derive it
            # in case we used x_start or x_prev prediction.
            eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
            return eps, out, out_orig

        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
        alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
        eps, out, out_orig = get_model_output(x, t)

        if order > 1 and old_out is None:
            # Pseudo Improved Euler
            old_eps = [eps]
            mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev) * eps
            eps_2, _, _ = get_model_output(mean_pred, t - 1)
            eps_prime = (eps + eps_2) / 2
            pred_prime = self._predict_xstart_from_eps(x, t, eps_prime)
            mean_pred = pred_prime * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev) * eps_prime
        else:
            # Pseudo Linear Multistep (Adams-Bashforth)
            old_eps = old_out["old_eps"]
            old_eps.append(eps)
            cur_order = min(order, len(old_eps))
            if cur_order == 1:
                eps_prime = old_eps[-1]
            elif cur_order == 2:
                eps_prime = (3 * old_eps[-1] - old_eps[-2]) / 2
            elif cur_order == 3:
                eps_prime = (23 * old_eps[-1] - 16 * old_eps[-2] + 5 * old_eps[-3]) / 12
            elif cur_order == 4:
                eps_prime = (55 * old_eps[-1] - 59 * old_eps[-2] + 37 * old_eps[-3] - 9 * old_eps[-4]) / 24
            else:
                raise RuntimeError('cur_order is invalid.')
            pred_prime = self._predict_xstart_from_eps(x, t, eps_prime)
            mean_pred = pred_prime * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev) * eps_prime

        if len(old_eps) >= order:
            old_eps.pop(0)

        nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        sample = mean_pred * nonzero_mask + out["pred_xstart"] * (1 - nonzero_mask)

        return {"sample": sample, "pred_xstart": out_orig["pred_xstart"], "old_eps": old_eps}
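    # The eps_prime combinations above are the classical Adams-Bashforth
    # multistep coefficients applied to the noise predictions, e.g. for order 2
    #     eps' = (3 eps_n - eps_{n-1}) / 2,
    # and for order 4
    #     eps' = (55 eps_n - 59 eps_{n-1} + 37 eps_{n-2} - 9 eps_{n-3}) / 24,
    # which extrapolates the eps trajectory before taking a deterministic
    # (eta = 0) DDIM-style step; the first step falls back to a Heun-like
    # "pseudo improved Euler" bootstrap since no eps history exists yet.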
""" final = None for sample in self.plms_sample_loop_progressive( model, shape, noise=noise, clip_denoised=clip_denoised, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, device=device, progress=progress, skip_timesteps=skip_timesteps, init_image=init_image, randomize_class=randomize_class, cond_fn_with_grad=cond_fn_with_grad, order=order, ): final = sample return final["sample"] def plms_sample_loop_progressive( self, model, shape, noise=None, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, progress=False, skip_timesteps=0, init_image=None, randomize_class=False, cond_fn_with_grad=False, order=2, ): """ Use PLMS to sample from the model and yield intermediate samples from each timestep of PLMS. Same usage as p_sample_loop_progressive(). """ if device is None: device = next(model.parameters()).device assert isinstance(shape, (tuple, list)) if noise is not None: img = noise else: img = th.randn(*shape, device=device) if skip_timesteps and init_image is None: init_image = th.zeros_like(img) indices = list(range(self.num_timesteps - skip_timesteps))[::-1] if init_image is not None: my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0] img = self.q_sample(init_image, my_t, img) if progress: # Lazy import so that we don't depend on tqdm. from tqdm.auto import tqdm indices = tqdm(indices) old_out = None for i in indices: t = th.tensor([i] * shape[0], device=device) if randomize_class and 'y' in model_kwargs: model_kwargs['y'] = th.randint(low=0, high=model.num_classes, size=model_kwargs['y'].shape, device=model_kwargs['y'].device) with th.no_grad(): out = self.plms_sample( model, img, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, cond_fn_with_grad=cond_fn_with_grad, order=order, old_out=old_out, ) yield out old_out = out img = out["sample"] def _vb_terms_bpd( self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None ): """ Get a term for the variational lower-bound. The resulting units are bits (rather than nats, as one might expect). This allows for comparison to other papers. :return: a dict with the following keys: - 'output': a shape [N] tensor of NLLs or KLs. - 'pred_xstart': the x_0 predictions. """ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( x_start=x_start, x_t=x_t, t=t ) out = self.p_mean_variance( model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs ) kl = normal_kl( true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] ) kl = mean_flat(kl) / np.log(2.0) decoder_nll = -discretized_gaussian_log_likelihood( x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] ) assert decoder_nll.shape == x_start.shape decoder_nll = mean_flat(decoder_nll) / np.log(2.0) # At the first timestep return the decoder NLL, # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) output = th.where((t == 0), decoder_nll, kl) return {"output": output, "pred_xstart": out["pred_xstart"]} ## ## training losses ## ## training losses ## def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, dataset=None): """ # training losses # training losses for rel/dist representations ## Compute training losses for a single timestep. :param model: the model to evaluate loss on. :param x_start: the [N x C x ...] tensor of inputs. :param t: a batch of timestep indices. :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. 
:param noise: if specified, the specific Gaussian noise to try to remove. :return: a dict with the key "loss" containing a tensor of shape [N]. Some mean or variance settings may also have other keys. """ # if self.args.train_diff: # set enc to evals # # # print(f"Setitng encoders to eval mode") # model.model.set_enc_to_eval() enc = model.model ## model.model mask = model_kwargs['y']['mask'] ## rot2xyz get_xyz = lambda sample: enc.rot2xyz(sample, mask=None, pose_rep=enc.pose_rep, translation=enc.translation, glob=enc.glob, ## rot2xyz; ## # jointstype='vertices', # 3.4 iter/sec # USED ALSO IN MotionCLIP jointstype='smpl', # 3.4 iter/sec vertstrans=False) # bsz x ws x nnj x 3 # # training lossesss # # base_pts: bsz x nnb x 3 # # base_normals: bsz x nnb x 3 # # base normals # base normals # # bsz x ws x nnjts x 3 # rhand_joints = x_start['rhand_joints'] # bsz x nnbase x 3 # base_pts = x_start['base_pts'] # bsz x ws x nnbase x 3 # base_normals = x_start['base_normals'] obj_rot = x_start['obj_rot'] # bsz x nf x 3 x 3 --> obj_rot obj_transl = x_start['obj_transl'] # obj_rot; obj_trans # # obj_transl: bsz x nf x 3 --> transl of obj base canon_rhand_joints = rhand_joints.clone() canon_base_pts = base_pts.unsqueeze(1).repeat(1, rhand_joints.size(1), 1, 1).contiguous() # bsz x nf x nn_base_pts x 3 # canon_base_normals = base_normals.unsqueeze(1).repeat(1, rhand_joints.size(1), 1, 1).contiguous() # bsz x nf x nn_base_pts x 3 # # base_pts = torch.matmul(base_pts.unsqueeze(1), obj_rot) + obj_transl.unsqueeze(-2) # bsz x nf x nn_base_pts x 3 # base_normals = torch.matmul(base_normals.unsqueeze(1), obj_rot) # bsz x nf x nn_base_pts x 3 # rhand_joints = torch.matmul(rhand_joints, obj_rot) + obj_transl.unsqueeze(-2) # bsz x nf x nn_base_pts x 3 # # dist_joints_to_base_pts_disp, dist_disp_along_dir, dist_disp_vt_dir # calculate_disp_quants_batched # velocity; and # for each of them --> (bsz x (nf - 1) x nn_joint) # # joints quantitiess # (joints, base_pts_trans, canon_joints, canon_base_normals) if self.args.use_temporal_rep_v2: disp_base_pts, dist_jts_to_base_pts, dist_disp_along_dir, dist_disp_vt_dir = calculate_disp_quants_batched_v2(rhand_joints, base_pts, canon_rhand_joints, canon_base_normals) # print(f"disp_base_pts: {disp_base_pts.size()}, dist_jts_to_base_pts: {dist_jts_to_base_pts.size()}, dist_disp_along_dir: {dist_disp_along_dir.size()}, dist_disp_vt_dir: {dist_disp_vt_dir.size()},") joints_quants = torch.cat( [disp_base_pts, dist_jts_to_base_pts.unsqueeze(-1), dist_disp_along_dir.unsqueeze(-1), dist_disp_vt_dir.unsqueeze(-1)], dim=-1 ) joints_quants_pred = torch.cat( [dist_jts_to_base_pts.unsqueeze(-1), dist_disp_along_dir.unsqueeze(-1), dist_disp_vt_dir.unsqueeze(-1)], dim=-1 ) else: dist_joints_to_base_pts_disp, dist_disp_along_dir, dist_disp_vt_dir = calculate_disp_quants_batched(rhand_joints, base_pts) # joints_quants: bsz x (nf - 1) x nnjoints x 3 --> joints_quants # joints_quants = torch.cat( # why we cannot predict this well ? 
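        # Added note: in the `use_temporal_rep_v2` branch above, `joints_quants`
        # stacks four per-joint quantities along the last dim -- the displacement
        # w.r.t. the base points plus three scalar channels (distance to base
        # points, displacement along the normal, displacement perpendicular to
        # it) -- while `joints_quants_pred` keeps only the three scalar channels.
        # In this v1 branch both tensors are the 3-channel version built from
        # `calculate_disp_quants_batched`, shaped bsz x (nf - 1) x nnjoints x 3.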
# [dist_joints_to_base_pts_disp.unsqueeze(-1), dist_disp_along_dir.unsqueeze(-1), dist_disp_vt_dir.unsqueeze(-1)], dim=-1 ) joints_quants_pred = joints_quants.clone() ### joints quants ### input_data = { 'joints_quants': joints_quants } # # input_data.update( # {k: x_start[k].clone() for k in x_start if k not in input_data} # ) # gaussian diffusion ours ## # rel_base_pts_to_rhand_joints in the input_data # if model_kwargs is None: model_kwargs = {} terms = {} # latents in the latent space # # sequence latents # # if self.args.train_diff: # with torch.no_grad(): # out_dict = model(input_data, self._scale_timesteps(t).clone()) # else: # clean_joint_seq_latents: seq x bs x d # # clean_joint_seq_latents = model(input_data, self._scale_timesteps(t).clone()) ## ### the strategy of removing noise from corresponding quantities ### out_dict = model(input_data, self._scale_timesteps(t).clone()) ### get model output dictionary ### KL_loss = 0. terms['rot_mse'] = 0. ### diff_jts ### # out dict of the # # reumse checkpoints #dec_in_dict dec_in_dict = {} if self.diff_basejtsrel: joints_quants_output = out_dict['joints_quants_output'] # joints_offset_output = out_dict['joints_offset_output'] # if self.args.pred_diff_noise: # basejtsrel_denoising_loss = torch.sum((joints_offset_output - noise_joints_offset_sequence) ** 2, dim=-1).mean(dim=-1).mean(dim=-1) # bsz x ws x nnjts x 3 --> mean and mena over dim=-1 # else: # # basejtsrel denoising losses ## # basejtsrel_denoising_loss = torch.sum((joints_offset_output - joints_offset_sequence) ** 2, dim=-1).mean(dim=-1).mean(dim=-1) # # if 'avg_jts_outputs' in out_dict: # avg jts outputs ## # avg_jts_outputs = out_dict['avg_jts_outputs'] # if self.args.pred_diff_noise: # avgjts_denoising_loss = torch.sum((avg_jts_outputs - noise_avg_joints_sequence) ** 2, dim=-1).mean(dim=-1) # else: # avgjts_denoising_loss = torch.sum((avg_jts_outputs - avg_joints_sequence) ** 2, dim=-1).mean(dim=-1) # else: # avgjts_denoising_loss = torch.zeros_like(basejtsrel_denoising_loss) if self.args.joint_quants_nn == 2: # joint_quant_pred_loss = torch.sum( # (joints_quants_output[..., :2] - joints_quants[..., :2]) ** 2, dim=-1 # ).mean(dim=-1).mean(dim=-1) joint_quant_pred_loss = torch.sum( (joints_quants_output[..., -2:] - joints_quants_pred[..., -2:]) ** 2, dim=-1 ).mean(dim=-1).mean(dim=-1) else: joint_quant_pred_loss = torch.sum( (joints_quants_output - joints_quants_pred) ** 2, dim=-1 ).mean(dim=-1).mean(dim=-1) joints_quants_latents = out_dict['joints_quants_latents'] terms['basejtrel_denoising_loss'] = joint_quant_pred_loss # terms['avgjts_denoising_loss'] = avgjts_denoising_loss # # terms['rot_mse'] += joint_quant_pred_loss # + avgjts_denoising_loss # jts denoising ## ## joint_quant_pred_loss -> quant pred loss # # joints_quants_latents: noise_joints_quants_latents = torch.randn_like(joints_quants_latents.detach()) # seq_len x bsz x latents pert_joints_quants_latents = self.q_sample( joints_quants_latents.permute(1, 0, 2).detach(), t, noise_joints_quants_latents.permute(1, 0, 2) ).permute(1, 0, 2) # denoising_joint_quants_feats --> denoising joint quants feats # denoised_joints_quants = model.model.denoising_joint_quants_feats(pert_joints_quants_latents, self._scale_timesteps(t).clone()) # joints_quants_latent_denoising_loss joints_quants_latent_denoising_loss = torch.sum( (denoised_joints_quants.permute(1, 0, 2) - noise_joints_quants_latents.permute(1, 0, 2)) ** 2, dim=-1 ).mean(dim=-1) # bsz size of the denoising loss # if not self.args.train_enc: terms['avgjts_denoising_loss'] 
= joints_quants_latent_denoising_loss terms['rot_mse'] += joints_quants_latent_denoising_loss # sv_out_in = { # 'model_output': model_output.detach().cpu().numpy(), # 'target': target.detach().cpu().numpy(), # 't': t.detach().cpu().numpy(), # } import os import datetime cur_time_stamp = datetime.datetime.now().timestamp() cur_time_stamp = str(cur_time_stamp) # sv_dir_rt = "/data1/sim/motion-diffusion-model" # sv_out_fn = os.path.join(sv_dir_rt, f"out_{cur_time_stamp}.npy") # np.save(sv_out_fn,sv_inter_dict ) # print(f"Samples saved to {sv_out_fn}") target_xyz, model_output_xyz = None, None if self.lambda_rcxyz > 0. and dataset.dataname not in ['motion_ours']: print(f"Calculating lambda_rcxyz!!!") target_xyz = get_xyz(target) # [bs, nvertices(vertices)/njoints(smpl), 3, nframes] model_output_xyz = get_xyz(model_output) # [bs, nvertices, 3, nframes] terms["rcxyz_mse"] = self.masked_l2(target_xyz, model_output_xyz, mask) # mean_flat((target_xyz - model_output_xyz) ** 2) if self.lambda_vel_rcxyz > 0.: if self.data_rep == 'rot6d' and dataset.dataname in ['humanact12', 'uestc']: target_xyz = get_xyz(target) if target_xyz is None else target_xyz model_output_xyz = get_xyz(model_output) if model_output_xyz is None else model_output_xyz target_xyz_vel = (target_xyz[:, :, :, 1:] - target_xyz[:, :, :, :-1]) model_output_xyz_vel = (model_output_xyz[:, :, :, 1:] - model_output_xyz[:, :, :, :-1]) terms["vel_xyz_mse"] = self.masked_l2(target_xyz_vel, model_output_xyz_vel, mask[:, :, :, 1:]) if self.lambda_fc > 0.: ## lambda fc ## torch.autograd.set_detect_anomaly(True) if self.data_rep == 'rot6d' and dataset.dataname in ['humanact12', 'uestc']: target_xyz = get_xyz(target) if target_xyz is None else target_xyz model_output_xyz = get_xyz(model_output) if model_output_xyz is None else model_output_xyz # 'L_Ankle', # 7, 'R_Ankle', # 8 , 'L_Foot', # 10, 'R_Foot', # 11 l_ankle_idx, r_ankle_idx, l_foot_idx, r_foot_idx = 7, 8, 10, 11 relevant_joints = [l_ankle_idx, l_foot_idx, r_ankle_idx, r_foot_idx] gt_joint_xyz = target_xyz[:, relevant_joints, :, :] # [BatchSize, 4, 3, Frames] gt_joint_vel = torch.linalg.norm(gt_joint_xyz[:, :, :, 1:] - gt_joint_xyz[:, :, :, :-1], axis=2) # [BatchSize, 4, Frames] fc_mask = torch.unsqueeze((gt_joint_vel <= 0.01), dim=2).repeat(1, 1, 3, 1) pred_joint_xyz = model_output_xyz[:, relevant_joints, :, :] # [BatchSize, 4, 3, Frames] pred_vel = pred_joint_xyz[:, :, :, 1:] - pred_joint_xyz[:, :, :, :-1] pred_vel[~fc_mask] = 0 terms["fc"] = self.masked_l2(pred_vel, torch.zeros(pred_vel.shape, device=pred_vel.device), mask[:, :, :, 1:]) self.lambda_vel = 0. if self.lambda_vel > 0.: target_vel = (target[..., 1:] - target[..., :-1]) model_output_vel = (model_output[..., 1:] - model_output[..., :-1]) terms["vel_mse"] = self.masked_l2(target_vel[:, :-1, :, :], # Remove last joint, is the root location! model_output_vel[:, :-1, :, :], mask[:, :, :, 1:]) # mean_flat((target_vel - model_output_vel) ** 2) terms["loss"] = terms["rot_mse"] + terms.get('vb', 0.) +\ (self.lambda_vel * terms.get('vel_mse', 0.)) +\ (self.lambda_rcxyz * terms.get('rcxyz_mse', 0.)) + \ (self.lambda_fc * terms.get('fc', 0.)) # else: # raise NotImplementedError(self.loss_type) return terms ## training losses def predict_sample_single_step(self, model, x_start, t, model_kwargs=None, noise=None, dataset=None): """ # s ## predict sa Compute training losses for a single timestep. :param model: the model to evaluate loss on. :param x_start: the [N x C x ...] tensor of inputs. :param t: a batch of timestep indices. 
:param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :param noise: if specified, the specific Gaussian noise to try to remove. :return: a dict with the key "loss" containing a tensor of shape [N]. Some mean or variance settings may also have other keys. """ # enc = model.model._modules['module'] # enc = model.model mask = model_kwargs['y']['mask'] ## rot2xyz; # # ### avg_joints, std_joints ### # if 'avg_joints' in model_kwargs['y']: avg_joints = model_kwargs['y']['avg_joints'].unsqueeze(-1) std_joints = model_kwargs['y']['std_joints'].unsqueeze(-1).unsqueeze(-1) else: avg_joints = None std_joints = None # ### avg_joints, std_joints ### # # get_xyz = lambda sample: enc.rot2xyz(sample, mask=None, pose_rep=enc.pose_rep, translation=enc.translation, # glob=enc.glob, ## rot2xyz; ## # # jointstype='vertices', # 3.4 iter/sec # USED ALSO IN MotionCLIP # jointstype='smpl', # 3.4 iter/sec # vertstrans=False) if model_kwargs is None: ## predict single steps --> model_kwargs = {} if noise is None: noise = th.randn_like(x_start) x_t = self.q_sample(x_start, t, noise=noise) # randn_like for x_start, t, x_t --> get x_t from x_start # # how we control the tiem stamp t? terms = {} if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: terms["loss"] = self._vb_terms_bpd( ## vb terms bpd # model=model, x_start=x_start, x_t=x_t, t=t, clip_denoised=False, model_kwargs=model_kwargs, )["output"] if self.loss_type == LossType.RESCALED_KL: terms["loss"] *= self.num_timesteps elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: model_output = model(x_t, self._scale_timesteps(t), **model_kwargs) # model_output ---> model x_t # if self.model_var_type in [ ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE, ]: # s B, C = x_t.shape[:2] assert model_output.shape == (B, C * 2, *x_t.shape[2:]) model_output, model_var_values = th.split(model_output, C, dim=1) # Learn the variance using the variational bound, but don't let # it affect our mean prediction. frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) terms["vb"] = self._vb_terms_bpd( model=lambda *args, r=frozen_out: r, x_start=x_start, x_t=x_t, t=t, clip_denoised=False, )["output"] if self.loss_type == LossType.RESCALED_MSE: # Divide by 1000 for equivalence with initial implementation. # Without a factor of 1/1000, the VB term hurts the MSE term. 
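                    # Added note: together with the detached-mean `frozen_out` above,
                    # this implements the hybrid objective of Improved DDPM (Nichol &
                    # Dhariwal): an MSE term on the mean prediction plus a down-weighted
                    # VLB term that only trains the learned variance,
                    #   L_hybrid = L_simple + lambda * L_vlb,  lambda ~ num_timesteps / 1000.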
terms["vb"] *= self.num_timesteps / 1000.0 target = { # q posterior mean variance # ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( x_start=x_start, x_t=x_t, t=t )[0], ModelMeanType.START_X: x_start, ModelMeanType.EPSILON: noise, }[self.model_mean_type] # model mean type --> mean type # assert model_output.shape == target.shape == x_start.shape # [bs, njoints, nfeats, nframes] terms["rot_mse"] = self.masked_l2(target, model_output, mask) # mean_flat(rot_mse) # model_output, target, t # ### avg_joints, std_joints ### # if avg_joints is not None: print(f"model_output: {model_output.size()}, target: {target.size()}, std_joints: {std_joints.size()}, avg_joints: {avg_joints.size()}") print(f"Denormalizing joints...") model_output = (model_output * std_joints) + avg_joints target = (target * std_joints) + avg_joints sv_out_in = { # 'model_output': model_output.detach().cpu().numpy(), 'target': target.detach().cpu().numpy(), 't': t.detach().cpu().numpy(), } import os import datetime cur_time_stamp = datetime.datetime.now().timestamp() cur_time_stamp = str(cur_time_stamp) sv_dir_rt = "/data1/sim/motion-diffusion-model" sv_out_fn = os.path.join(sv_dir_rt, f"out_{cur_time_stamp}.npy") np.save(sv_out_fn,sv_out_in ) print(f"Samples saved to {sv_out_fn}") target_xyz, model_output_xyz = None, None self.lambda_rcxyz = 0. if self.lambda_rcxyz > 0.: target_xyz = get_xyz(target) # [bs, nvertices(vertices)/njoints(smpl), 3, nframes] model_output_xyz = get_xyz(model_output) # [bs, nvertices, 3, nframes] terms["rcxyz_mse"] = self.masked_l2(target_xyz, model_output_xyz, mask) # mean_flat((target_xyz - model_output_xyz) ** 2) if self.lambda_vel_rcxyz > 0.: if self.data_rep == 'rot6d' and dataset.dataname in ['humanact12', 'uestc']: target_xyz = get_xyz(target) if target_xyz is None else target_xyz model_output_xyz = get_xyz(model_output) if model_output_xyz is None else model_output_xyz target_xyz_vel = (target_xyz[:, :, :, 1:] - target_xyz[:, :, :, :-1]) model_output_xyz_vel = (model_output_xyz[:, :, :, 1:] - model_output_xyz[:, :, :, :-1]) terms["vel_xyz_mse"] = self.masked_l2(target_xyz_vel, model_output_xyz_vel, mask[:, :, :, 1:]) if self.lambda_fc > 0.: ## lambda fc ## torch.autograd.set_detect_anomaly(True) if self.data_rep == 'rot6d' and dataset.dataname in ['humanact12', 'uestc']: target_xyz = get_xyz(target) if target_xyz is None else target_xyz model_output_xyz = get_xyz(model_output) if model_output_xyz is None else model_output_xyz # 'L_Ankle', # 7, 'R_Ankle', # 8 , 'L_Foot', # 10, 'R_Foot', # 11 l_ankle_idx, r_ankle_idx, l_foot_idx, r_foot_idx = 7, 8, 10, 11 relevant_joints = [l_ankle_idx, l_foot_idx, r_ankle_idx, r_foot_idx] gt_joint_xyz = target_xyz[:, relevant_joints, :, :] # [BatchSize, 4, 3, Frames] gt_joint_vel = torch.linalg.norm(gt_joint_xyz[:, :, :, 1:] - gt_joint_xyz[:, :, :, :-1], axis=2) # [BatchSize, 4, Frames] fc_mask = torch.unsqueeze((gt_joint_vel <= 0.01), dim=2).repeat(1, 1, 3, 1) pred_joint_xyz = model_output_xyz[:, relevant_joints, :, :] # [BatchSize, 4, 3, Frames] pred_vel = pred_joint_xyz[:, :, :, 1:] - pred_joint_xyz[:, :, :, :-1] pred_vel[~fc_mask] = 0 terms["fc"] = self.masked_l2(pred_vel, torch.zeros(pred_vel.shape, device=pred_vel.device), mask[:, :, :, 1:]) if self.lambda_vel > 0.: target_vel = (target[..., 1:] - target[..., :-1]) model_output_vel = (model_output[..., 1:] - model_output[..., :-1]) terms["vel_mse"] = self.masked_l2(target_vel[:, :-1, :, :], # Remove last joint, is the root location! 
model_output_vel[:, :-1, :, :], mask[:, :, :, 1:]) # mean_flat((target_vel - model_output_vel) ** 2) terms["loss"] = terms["rot_mse"] + terms.get('vb', 0.) +\ (self.lambda_vel * terms.get('vel_mse', 0.)) +\ (self.lambda_rcxyz * terms.get('rcxyz_mse', 0.)) + \ (self.lambda_fc * terms.get('fc', 0.)) else: raise NotImplementedError(self.loss_type) return terms, model_output, target, t def fc_loss_rot_repr(self, gt_xyz, pred_xyz, mask): def to_np_cpu(x): return x.detach().cpu().numpy() """ pose_xyz: SMPL batch tensor of shape: [BatchSize, 24, 3, Frames] """ # 'L_Ankle', # 7, 'R_Ankle', # 8 , 'L_Foot', # 10, 'R_Foot', # 11 l_ankle_idx, r_ankle_idx = 7, 8 l_foot_idx, r_foot_idx = 10, 11 """ Contact calculated by 'Kfir Method' Commented code)""" # contact_signal = torch.zeros((pose_xyz.shape[0], pose_xyz.shape[3], 2), device=pose_xyz.device) # [BatchSize, Frames, 2] # left_xyz = 0.5 * (pose_xyz[:, l_ankle_idx, :, :] + pose_xyz[:, l_foot_idx, :, :]) # [BatchSize, 3, Frames] # right_xyz = 0.5 * (pose_xyz[:, r_ankle_idx, :, :] + pose_xyz[:, r_foot_idx, :, :]) # left_z, right_z = left_xyz[:, 2, :], right_xyz[:, 2, :] # [BatchSize, Frames] # left_velocity = torch.linalg.norm(left_xyz[:, :, 2:] - left_xyz[:, :, :-2], axis=1) # [BatchSize, Frames] # right_velocity = torch.linalg.norm(left_xyz[:, :, 2:] - left_xyz[:, :, :-2], axis=1) # # left_z_mask = left_z <= torch.mean(torch.sort(left_z)[0][:, :left_z.shape[1] // 5], axis=-1) # left_z_mask = torch.stack([left_z_mask, left_z_mask], dim=-1) # [BatchSize, Frames, 2] # left_z_mask[:, :, 1] = False # Blank right side # contact_signal[left_z_mask] = 0.4 # # right_z_mask = right_z <= torch.mean(torch.sort(right_z)[0][:, :right_z.shape[1] // 5], axis=-1) # right_z_mask = torch.stack([right_z_mask, right_z_mask], dim=-1) # [BatchSize, Frames, 2] # right_z_mask[:, :, 0] = False # Blank left side # contact_signal[right_z_mask] = 0.4 # contact_signal[left_z <= (torch.mean(torch.sort(left_z)[:left_z.shape[0] // 5]) + 20), 0] = 1 # contact_signal[right_z <= (torch.mean(torch.sort(right_z)[:right_z.shape[0] // 5]) + 20), 1] = 1 # plt.plot(to_np_cpu(left_z[0]), label='left_z') # plt.plot(to_np_cpu(left_velocity[0]), label='left_velocity') # plt.plot(to_np_cpu(contact_signal[0, :, 0]), label='left_fc') # plt.grid() # plt.legend() # plt.show() # plt.plot(to_np_cpu(right_z[0]), label='right_z') # plt.plot(to_np_cpu(right_velocity[0]), label='right_velocity') # plt.plot(to_np_cpu(contact_signal[0, :, 1]), label='right_fc') # plt.grid() # plt.legend() # plt.show() gt_joint_xyz = gt_xyz[:, [l_ankle_idx, l_foot_idx, r_ankle_idx, r_foot_idx], :, :] # [BatchSize, 4, 3, Frames] gt_joint_vel = torch.linalg.norm(gt_joint_xyz[:, :, :, 1:] - gt_joint_xyz[:, :, :, :-1], axis=2) # [BatchSize, 4, Frames] fc_mask = (gt_joint_vel <= 0.01) pred_joint_xyz = pred_xyz[:, [l_ankle_idx, l_foot_idx, r_ankle_idx, r_foot_idx], :, :] # [BatchSize, 4, 3, Frames] pred_joint_vel = torch.linalg.norm(pred_joint_xyz[:, :, :, 1:] - pred_joint_xyz[:, :, :, :-1], axis=2) # [BatchSize, 4, Frames] pred_joint_vel[~fc_mask] = 0 # Blank non-contact velocities frames. 
[BS,4,FRAMES] pred_joint_vel = torch.unsqueeze(pred_joint_vel, dim=2) """DEBUG CODE""" # print(f'mask: {mask.shape}') # print(f'pred_joint_vel: {pred_joint_vel.shape}') # plt.title(f'Joint: {joint_idx}') # plt.plot(to_np_cpu(gt_joint_vel[0]), label='velocity') # plt.plot(to_np_cpu(fc_mask[0]), label='fc') # plt.grid() # plt.legend() # plt.show() return self.masked_l2(pred_joint_vel, torch.zeros(pred_joint_vel.shape, device=pred_joint_vel.device), mask[:, :, :, 1:]) # TODO - NOT USED YET, JUST COMMITING TO NOT DELETE THIS AND KEEP INITIAL IMPLEMENTATION, NOT DONE! def foot_contact_loss_humanml3d(self, target, model_output): # root_rot_velocity (B, seq_len, 1) # root_linear_velocity (B, seq_len, 2) # root_y (B, seq_len, 1) # ric_data (B, seq_len, (joint_num - 1)*3) , XYZ # rot_data (B, seq_len, (joint_num - 1)*6) , 6D # local_velocity (B, seq_len, joint_num*3) , XYZ # foot contact (B, seq_len, 4) , target_fc = target[:, -4:, :, :] root_rot_velocity = target[:, :1, :, :] root_linear_velocity = target[:, 1:3, :, :] root_y = target[:, 3:4, :, :] ric_data = target[:, 4:67, :, :] # 4+(3*21)=67 rot_data = target[:, 67:193, :, :] # 67+(6*21)=193 local_velocity = target[:, 193:259, :, :] # 193+(3*22)=259 contact = target[:, 259:, :, :] # 193+(3*22)=259 contact_mask_gt = contact > 0.5 # contact mask order for indexes are fid_l [7, 10], fid_r [8, 11] vel_lf_7 = local_velocity[:, 7 * 3:8 * 3, :, :] vel_rf_8 = local_velocity[:, 8 * 3:9 * 3, :, :] vel_lf_10 = local_velocity[:, 10 * 3:11 * 3, :, :] vel_rf_11 = local_velocity[:, 11 * 3:12 * 3, :, :] calc_vel_lf_7 = ric_data[:, 6 * 3:7 * 3, :, 1:] - ric_data[:, 6 * 3:7 * 3, :, :-1] calc_vel_rf_8 = ric_data[:, 7 * 3:8 * 3, :, 1:] - ric_data[:, 7 * 3:8 * 3, :, :-1] calc_vel_lf_10 = ric_data[:, 9 * 3:10 * 3, :, 1:] - ric_data[:, 9 * 3:10 * 3, :, :-1] calc_vel_rf_11 = ric_data[:, 10 * 3:11 * 3, :, 1:] - ric_data[:, 10 * 3:11 * 3, :, :-1] # vel_foots = torch.stack([vel_lf_7, vel_lf_10, vel_rf_8, vel_rf_11], dim=1) for chosen_vel_foot_calc, chosen_vel_foot, joint_idx, contact_mask_idx in zip( [calc_vel_lf_7, calc_vel_rf_8, calc_vel_lf_10, calc_vel_rf_11], [vel_lf_7, vel_lf_10, vel_rf_8, vel_rf_11], [7, 10, 8, 11], [0, 1, 2, 3]): tmp_mask_gt = contact_mask_gt[:, contact_mask_idx, :, :].cpu().detach().numpy().reshape(-1).astype(int) chosen_vel_norm = np.linalg.norm(chosen_vel_foot.cpu().detach().numpy().reshape((3, -1)), axis=0) chosen_vel_calc_norm = np.linalg.norm(chosen_vel_foot_calc.cpu().detach().numpy().reshape((3, -1)), axis=0) print(tmp_mask_gt.shape) print(chosen_vel_foot.shape) print(chosen_vel_calc_norm.shape) import matplotlib.pyplot as plt plt.plot(tmp_mask_gt, label='FC mask') plt.plot(chosen_vel_norm, label='Vel. XYZ norm (from vector)') plt.plot(chosen_vel_calc_norm, label='Vel. XYZ norm (calculated diff XYZ)') plt.title(f'FC idx {contact_mask_idx}, Joint Index {joint_idx}') plt.legend() plt.show() # print(vel_foots.shape) return 0 # TODO - NOT USED YET, JUST COMMITING TO NOT DELETE THIS AND KEEP INITIAL IMPLEMENTATION, NOT DONE! 
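    # Added reference (derived from the slicing used in the two debug helpers
    # around here): the HumanML3D feature vector is laid out per frame as
    #   [0:1]     root_rot_velocity     (1)
    #   [1:3]     root_linear_velocity  (2)
    #   [3:4]     root_y                (1)
    #   [4:67]    ric_data, local XYZ   ((joint_num - 1) * 3 = 63)
    #   [67:193]  rot_data, 6D          ((joint_num - 1) * 6 = 126)
    #   [193:259] local_velocity, XYZ   (joint_num * 3 = 66)
    #   [259:263] foot contact          (4)
    # for joint_num = 22, i.e. 263 channels in total.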
def velocity_consistency_loss_humanml3d(self, target, model_output): # root_rot_velocity (B, seq_len, 1) # root_linear_velocity (B, seq_len, 2) # root_y (B, seq_len, 1) # ric_data (B, seq_len, (joint_num - 1)*3) , XYZ # rot_data (B, seq_len, (joint_num - 1)*6) , 6D # local_velocity (B, seq_len, joint_num*3) , XYZ # foot contact (B, seq_len, 4) , target_fc = target[:, -4:, :, :] root_rot_velocity = target[:, :1, :, :] root_linear_velocity = target[:, 1:3, :, :] root_y = target[:, 3:4, :, :] ric_data = target[:, 4:67, :, :] # 4+(3*21)=67 rot_data = target[:, 67:193, :, :] # 67+(6*21)=193 local_velocity = target[:, 193:259, :, :] # 193+(3*22)=259 contact = target[:, 259:, :, :] # 193+(3*22)=259 calc_vel_from_xyz = ric_data[:, :, :, 1:] - ric_data[:, :, :, :-1] velocity_from_vector = local_velocity[:, 3:, :, 1:] # Slicing out root r_rot_quat, r_pos = motion_process.recover_root_rot_pos(target.permute(0, 2, 3, 1).type(th.FloatTensor)) print(f'r_rot_quat: {r_rot_quat.shape}') print(f'calc_vel_from_xyz: {calc_vel_from_xyz.shape}') calc_vel_from_xyz = calc_vel_from_xyz.permute(0, 2, 3, 1) calc_vel_from_xyz = calc_vel_from_xyz.reshape((1, 1, -1, 21, 3)).type(th.FloatTensor) r_rot_quat_adapted = r_rot_quat[..., :-1, None, :].repeat((1,1,1,21,1)).to(calc_vel_from_xyz.device) print(f'calc_vel_from_xyz: {calc_vel_from_xyz.shape} , {calc_vel_from_xyz.device}') print(f'r_rot_quat_adapted: {r_rot_quat_adapted.shape}, {r_rot_quat_adapted.device}') calc_vel_from_xyz = motion_process.qrot(r_rot_quat_adapted, calc_vel_from_xyz) calc_vel_from_xyz = calc_vel_from_xyz.reshape((1, 1, -1, 21 * 3)) calc_vel_from_xyz = calc_vel_from_xyz.permute(0, 3, 1, 2) print(f'calc_vel_from_xyz: {calc_vel_from_xyz.shape} , {calc_vel_from_xyz.device}') import matplotlib.pyplot as plt for i in range(21): plt.plot(np.linalg.norm(calc_vel_from_xyz[:,i*3:(i+1)*3,:,:].cpu().detach().numpy().reshape((3, -1)), axis=0), label='Calc Vel') plt.plot(np.linalg.norm(velocity_from_vector[:,i*3:(i+1)*3,:,:].cpu().detach().numpy().reshape((3, -1)), axis=0), label='Vector Vel') plt.title(f'Joint idx: {i}') plt.legend() plt.show() print(calc_vel_from_xyz.shape) print(velocity_from_vector.shape) diff = calc_vel_from_xyz-velocity_from_vector print(np.linalg.norm(diff.cpu().detach().numpy().reshape((63, -1)), axis=0)) return 0 def _prior_bpd(self, x_start): """ Get the prior KL term for the variational lower-bound, measured in bits-per-dim. This term can't be optimized, as it only depends on the encoder. :param x_start: the [N x C x ...] tensor of inputs. :return: a batch of [N] KL values (in bits), one per batch element. """ batch_size = x_start.shape[0] t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) kl_prior = normal_kl( mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 ) return mean_flat(kl_prior) / np.log(2.0) def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): """ Compute the entire variational lower-bound, measured in bits-per-dim, as well as other related quantities. :param model: the model to evaluate loss on. :param x_start: the [N x C x ...] tensor of inputs. :param clip_denoised: if True, clip denoised samples. :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :return: a dict containing the following keys: - total_bpd: the total variational lower-bound, per batch element. - prior_bpd: the prior term in the lower-bound. 
- vb: an [N x T] tensor of terms in the lower-bound. - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. - mse: an [N x T] tensor of epsilon MSEs for each timestep. """ device = x_start.device batch_size = x_start.shape[0] vb = [] xstart_mse = [] mse = [] for t in list(range(self.num_timesteps))[::-1]: t_batch = th.tensor([t] * batch_size, device=device) noise = th.randn_like(x_start) x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) # Calculate VLB term at the current timestep with th.no_grad(): out = self._vb_terms_bpd( model, x_start=x_start, x_t=x_t, t=t_batch, clip_denoised=clip_denoised, model_kwargs=model_kwargs, ) vb.append(out["output"]) xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) mse.append(mean_flat((eps - noise) ** 2)) vb = th.stack(vb, dim=1) xstart_mse = th.stack(xstart_mse, dim=1) mse = th.stack(mse, dim=1) prior_bpd = self._prior_bpd(x_start) total_bpd = vb.sum(dim=1) + prior_bpd return { "total_bpd": total_bpd, "prior_bpd": prior_bpd, "vb": vb, "xstart_mse": xstart_mse, "mse": mse, } class GaussianDiffusionV8: """ Utilities for training and sampling diffusion models. Ported directly from here, and then adapted over time to further experimentation. https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 :param betas: a 1-D numpy array of betas for each diffusion timestep, ## pass floating point timesteps ## starting at T and going to 1. :param model_mean_type: a ModelMeanType determining what the model outputs. :param model_var_type: a ModelVarType determining how variance is output. :param loss_type: a LossType determining the loss function to use. :param rescale_timesteps: if True, pass floating point timesteps into the model so that they are always scaled like in the original paper (0 to 1000). """ def __init__( self, *, betas, model_mean_type, model_var_type, loss_type, rescale_timesteps=False, lambda_rcxyz=0., lambda_vel=0., lambda_pose=1., lambda_orient=1., lambda_loc=1., data_rep='rot6d', lambda_root_vel=0., lambda_vel_rcxyz=0., lambda_fc=0., denoising_stra="rep", inter_optim=False, args=None, ): self.model_mean_type = model_mean_type self.model_var_type = model_var_type ## model var type ## self.loss_type = loss_type self.rescale_timesteps = rescale_timesteps self.data_rep = data_rep self.args = args # possibly None ### GET the diff. suit ### self.diff_jts = self.args.diff_jts self.diff_basejtsrel = self.args.diff_basejtsrel self.diff_basejtse = self.args.diff_basejtse self.diff_realbasejtsrel = self.args.diff_realbasejtsrel ### GET the diff. suit ### if data_rep != 'rot_vel' and lambda_pose != 1.: raise ValueError('lambda_pose is relevant only when training on velocities!') self.lambda_pose = lambda_pose self.lambda_orient = lambda_orient self.lambda_loc = lambda_loc self.lambda_rcxyz = lambda_rcxyz self.lambda_vel = lambda_vel self.lambda_root_vel = lambda_root_vel self.lambda_vel_rcxyz = lambda_vel_rcxyz self.lambda_fc = lambda_fc ### === denoising_stra for the denoising process === ### self.denoising_stra = denoising_stra self.inter_optim = inter_optim if self.lambda_rcxyz > 0. or self.lambda_vel > 0. or self.lambda_root_vel > 0. or \ self.lambda_vel_rcxyz > 0. or self.lambda_fc > 0.: assert self.loss_type == LossType.MSE, 'Geometric losses are supported by MSE loss type only!' 
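        # Note: `VarianceSchedule` is assumed to be defined elsewhere in this
        # codebase (it is not imported in this file); it caches the per-step
        # alphas, cumulative alpha_bars, betas and sampling sigmas that the
        # c0/c1 ancestral updates below read via `self.var_sched.alphas[t]`,
        # `self.var_sched.alpha_bars[t]` and `self.var_sched.get_sigmas(t, .)`.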
self.var_sched = VarianceSchedule(len(betas), torch.tensor(betas, dtype=torch.float64)) # Use float64 for accuracy. betas = np.array(betas, dtype=np.float64) self.betas = betas assert len(betas.shape) == 1, "betas must be 1-D" ## betas assert (betas > 0).all() and (betas <= 1).all() self.num_timesteps = int(betas.shape[0]) alphas = 1.0 - betas self.alphas_cumprod = np.cumprod(alphas, axis=0) self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) ## alphas cumprod self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) # calculations for diffusion q(x_t | x_{t-1}) and others self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) # calculations for posterior q(x_{t-1} | x_t, x_0) self.posterior_variance = ( betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) ) # log calculation clipped because the posterior variance is 0 at the # beginning of the diffusion chain. self.posterior_log_variance_clipped = np.log( np.append(self.posterior_variance[1], self.posterior_variance[1:]) ) self.posterior_mean_coef1 = ( betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) ) self.posterior_mean_coef2 = ( (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) ) self.l2_loss = lambda a, b: (a - b) ** 2 # th.nn.MSELoss(reduction='none') # must be None for handling mask later on. def masked_l2(self, a, b, mask): # assuming a.shape == b.shape == bs, J, Jdim, seqlen # assuming mask.shape == bs, 1, 1, seqlen loss = self.l2_loss(a, b) loss = sum_flat(loss * mask.float()) # gives \sigma_euclidean over unmasked elements n_entries = a.shape[1] * a.shape[2] non_zero_elements = sum_flat(mask) * n_entries ## non-zero-elements # print('mask', mask.shape) # print('non_zero_elements', non_zero_elements) # print('loss', loss) mse_loss_val = loss / non_zero_elements # print('mse_loss_val', mse_loss_val) return mse_loss_val def q_mean_variance(self, x_start, t): """ Get the distribution q(x_t | x_0). # q-mean-variance # :param x_start: the [N x C x ...] tensor of noiseless inputs. :param t: the number of diffusion steps (minus 1). Here, 0 means one step. :return: A tuple (mean, variance, log_variance), all of x_start's shape. """ mean = ( _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start ) variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) log_variance = _extract_into_tensor( self.log_one_minus_alphas_cumprod, t, x_start.shape ) return mean, variance, log_variance def q_sample(self, x_start, t, noise=None): """ Diffuse the dataset for a given number of diffusion steps. In other words, sample from q(x_t | x_0). ## q pos :param x_start: the initial dataset batch. :param t: the number of diffusion steps (minus 1). Here, 0 means one step. :param noise: if specified, the split-out normal noise. :return: A noisy version of x_start. 
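        In closed form this is

            x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,  eps ~ N(0, I)

        e.g. (a sketch; `diffusion` names an instance of this class):

            t = th.randint(0, diffusion.num_timesteps,
                           (x_start.shape[0],), device=x_start.device)
            x_t = diffusion.q_sample(x_start, t)  # same shape as x_start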
""" if noise is None: noise = th.randn_like(x_start) assert noise.shape == x_start.shape return ( _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise ### variance xxx noise ### ) ## q_sample, def q_posterior_mean_variance(self, x_start, x_t, t): ## q_posterior """ Compute the mean and variance of the diffusion posterior: q(x_{t-1} | x_t, x_0) """ assert x_start.shape == x_t.shape posterior_mean = ( # posterior mean and variance # _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t ) posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) posterior_log_variance_clipped = _extract_into_tensor( self.posterior_log_variance_clipped, t, x_t.shape ) assert ( posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == x_start.shape[0] ) return posterior_mean, posterior_variance, posterior_log_variance_clipped def p_mean_variance_cond( ## get mean data ## self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None ): """ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of the initial x, x_0. :param model: the model, which takes a signal and a batch of timesteps as input. :param x: the [N x C x ...] tensor at time t. :param t: a 1-D Tensor of timesteps. :param clip_denoised: if True, clip the denoised signal into [-1, 1]. # # clip the denoised signal ## :param denoised_fn: if not None, a function which applies to the x_start prediction before it is used to sample. Applies before clip_denoised. # denoised fn :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :return: a dict with the following keys: - 'mean': the model mean output. - 'variance': the model variance output. - 'log_variance': the log of 'variance'. - 'pred_xstart': the prediction for x_0. """ # p_mean_varaince # if model_kwargs is None: model_kwargs = {} B = x['base_pts'].shape[0] assert t.shape == (B,) # print(f"t_shape: {t.shape}", "base_pts", x['base_pts'].size()) input_data = x out_dict = model(input_data, self._scale_timesteps(t).clone()) rt_dict = {} real_basejtsrel_seq_rt_dict = {} basejtsrel_seq_rt_dict = {} realbasejtsrel_to_joints_rt_dict = {} model_variance, model_log_variance = self.posterior_variance, self.posterior_log_variance_clipped if self.diff_realbasejtsrel and self.diff_basejtsrel: # diff_realbasejtsrel, real_basejtsrel_seq_rt_dict pert_rel_base_pts_outputs = x['pert_joints_offset_sequence'] basejtsrel_output = out_dict['joints_offset_output'] score_jts = basejtsrel_output # print(f"basejtsrel_output: {basejtsrel_output.size()}") # if self.args.use_var_sched: # bsz = basejtsrel_output.size(0) # t_item = t[0].item() # alpha = self.var_sched.alphas[t_item] # alpha_bar = self.var_sched.alpha_bars[t_item] # sigma = self.var_sched.get_sigmas(t_item, 0.) 
# sigma # # c0 = 1.0 / torch.sqrt(alpha) # c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar) # beta = self.var_sched.betas[[t[0].item()] * bsz] # z = torch.randn_like(basejtsrel_output) if t_item > 0 else torch.zeros_like(basejtsrel_output) # basejtsrel_output = c0 * (pert_rel_base_pts_outputs - c1 * basejtsrel_output) + sigma * z # theta # else: # basejtsrel_output = self._predict_xstart_from_eps(pert_rel_base_pts_outputs, t=t, eps=basejtsrel_output) real_dec_basejtsrel = out_dict['real_dec_basejtsrel'] # bsz x nf x nnj x nnb x 3 pert_rel_base_pts_to_rhand_joints = x['pert_rel_base_pts_to_rhand_joints'] # combine those two things # if self.args.pred_diff_noise and not self.args.train_enc: if self.args.add_noise_onjts: # add noise onjts # if self.args.real_basejtsrel_norm_stra == "std": ### std denormed_rel_base_pts_to_rhand_joints = pert_rel_base_pts_to_rhand_joints * x['std_rel_base_pts_to_rhand_joints'] + x['avg_rel_base_pts_to_rhand_joints'] else: denormed_rel_base_pts_to_rhand_joints = pert_rel_base_pts_to_rhand_joints # jts_fr_basepts = denormed_rel_base_pts_to_rhand_joints + x['normed_base_pts'].unsqueeze(1).unsqueeze(1) # jts_fr_basepts = jts_fr_basepts.mean(dim=-2) jts_fr_basepts = pert_rel_base_pts_outputs # pert # score_jts_fr_basepts = real_dec_basejtsrel.mean(dim=-2) # bsz x seq x nnjts x 3 score_jts_fr_basepts = real_dec_basejtsrel[..., self.args.sel_basepts_idx, :] t_item = t[0].item() alpha_bar = self.var_sched.alpha_bars[t_item] # combined_socre = score_jts - torch.sqrt(1. - alpha_bar) * score_jts_fr_basepts # combined_socre = score_jts # - torch.sqrt(1. - alpha_bar) * socre_jts_fr_basepts # combined_socre = score_jts_fr_basepts if self.args.only_cmb_finger: combined_socre = score_jts # combined_socre[..., -5:, :] = combined_socre[..., -5:, :] - (torch.sqrt(1. - alpha_bar) * score_jts_fr_basepts)[..., -5:, :] # combined_socre[..., -5:, :] = combined_socre[..., -5:, :] * 0.7 + score_jts_fr_basepts[..., -5:, :] * 0.3 # combined_socre = combined_socre * 0.2 + score_jts_fr_basepts * 0.8 combined_socre = combined_socre * 0.1 + score_jts_fr_basepts * 0.9 # combined_socre = combined_socre * 0.05 + score_jts_fr_basepts * 0.95 # combined_socre = combined_socre * 0.5 + score_jts_fr_basepts * 0.5 # combined_socre = combined_socre # combined_socre = score_jts_fr_basepts else: combined_socre = score_jts - torch.sqrt(1. - alpha_bar) * score_jts_fr_basepts combined_socre = score_jts # not cmb finger # # combined_socre = score_jts_fr_basepts # print(f"real_dec_basejtsrel: {real_dec_basejtsrel.size()}, score_jts_fr_basepts: {score_jts_fr_basepts.size()}, combined_socre: {combined_socre.size()}, score_jts: {score_jts.size()}") if self.args.use_var_sched: bsz = real_dec_basejtsrel.size(0) t_item = t[0].item() alpha = self.var_sched.alphas[t_item] alpha_bar = self.var_sched.alpha_bars[t_item] sigma = self.var_sched.get_sigmas(t_item, 0.) # sigma # # sigma = sigma / 2. 
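                        # Added note: the `combined_socre` (sic) computed above blends two
                        # epsilon estimates for the same joints -- `score_jts` from the
                        # joint-offset head and `score_jts_fr_basepts` from the base-point
                        # relative head. With `only_cmb_finger` the surviving assignment is
                        # the 0.1 / 0.9 linear interpolation; otherwise the final
                        # `combined_socre = score_jts` overwrites the conditioned version.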
# use_var_sched -> # c0 = 1.0 / torch.sqrt(alpha) c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar) beta = self.var_sched.betas[[t[0].item()] * bsz] z = torch.randn_like(jts_fr_basepts) if t_item > 0 else torch.zeros_like(jts_fr_basepts) dec_jts_fr_basepts = c0 * (jts_fr_basepts - c1 * combined_socre) + sigma * z # theta else: # x_{t-1} dec_jts_fr_basepts = self._predict_xstart_from_eps(jts_fr_basepts, t=t, eps=combined_socre) real_dec_basejtsrel = dec_jts_fr_basepts.unsqueeze(-2) - x['normed_base_pts'].unsqueeze(1).unsqueeze(1) # print(f"dec_jts_fr_basepts: {dec_jts_fr_basepts.size()}, normed_base_pts: ", x['normed_base_pts'].size(), "real_dec_basejtsrel:", real_dec_basejtsrel.size()) if self.args.real_basejtsrel_norm_stra == "std": real_dec_basejtsrel = (real_dec_basejtsrel - x['avg_rel_base_pts_to_rhand_joints']) / x['std_rel_base_pts_to_rhand_joints'] # print(f"real_dec_basejtsrel: {real_dec_basejtsrel.size()}, denormed_rel_base_pts_to_rhand_joints: {denormed_rel_base_pts_to_rhand_joints.size()}, jts_fr_basepts: {jts_fr_basepts.size()}") elif self.args.add_noise_onjts_single: # add noise on single joint jts_fr_basepts = pert_rel_base_pts_to_rhand_joints + x['normed_base_pts'].unsqueeze(1).unsqueeze(1) jts_fr_basepts = jts_fr_basepts.mean(dim=-2) dec_jts_fr_basepts = real_dec_basejtsrel.mean(dim=-2) # + x['normed_base_pts'].unsqueeze(1).unsqueeze(1) # rel dec basejtsrel # # dec_jts_fr_basepts = real_dec_basejtsrel[..., 0, :] t_item = t[0].item() alpha_bar = self.var_sched.alpha_bars[t_item] socre_jts_fr_basepts = dec_jts_fr_basepts if self.args.only_cmb_finger: combined_socre = score_jts # strategy 1 --> conditioning # # combined_socre[..., -5:, :] = combined_socre[..., -5:, :] - (torch.sqrt(1. - alpha_bar) * socre_jts_fr_basepts)[..., -5:, :] # strategy 2 --> linear interpolation # # combined_socre[..., -5:, :] = combined_socre[..., -5:, :] * 0.7 + socre_jts_fr_basepts[..., -5:, :] * 0.3 combined_socre[..., -5:, :] = combined_socre[..., -5:, :] * 0.5 + socre_jts_fr_basepts[..., -5:, :] * 0.5 else: combined_socre = score_jts - torch.sqrt(1. - alpha_bar) * socre_jts_fr_basepts combined_socre = score_jts # - torch.sqrt(1. - alpha_bar) * socre_jts_fr_basepts # combined_socre = socre_jts_fr_basepts if self.args.use_var_sched: bsz = dec_jts_fr_basepts.size(0) alpha = self.var_sched.alphas[t_item] alpha_bar = self.var_sched.alpha_bars[t_item] sigma = self.var_sched.get_sigmas(t_item, 0.) 
# sigma # c0 = 1.0 / torch.sqrt(alpha) c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar) beta = self.var_sched.betas[[t[0].item()] * bsz] z = torch.randn_like(combined_socre) if t_item > 0 else torch.zeros_like(jts_fr_basepts) dec_jts_fr_basepts = c0 * (pert_rel_base_pts_outputs - c1 * combined_socre) + sigma * z # theta # bsz x seq x nnjts x nnbase x 3 # else: dec_jts_fr_basepts = self._predict_xstart_from_eps(pert_rel_base_pts_outputs, t=t, eps=combined_socre) real_dec_basejtsrel = dec_jts_fr_basepts.unsqueeze(-2) - x['normed_base_pts'].unsqueeze(1).unsqueeze(1) else: # raise ValueError(f"Add noise directly --- not implemented yet") # # input_data # self.args.real_basejtsrel_norm_stra == "std": # avg_rel_base_pts_to_rhand_joints, std_rel_base_pts_to_rhand_joints if self.args.real_basejtsrel_norm_stra == "std": ### std denormed_rel_base_pts_to_rhand_joints = pert_rel_base_pts_to_rhand_joints * x['std_rel_base_pts_to_rhand_joints'] + x['avg_rel_base_pts_to_rhand_joints'] else: denormed_rel_base_pts_to_rhand_joints = pert_rel_base_pts_to_rhand_joints jts_fr_basepts = denormed_rel_base_pts_to_rhand_joints + x['normed_base_pts'].unsqueeze(1).unsqueeze(1) jts_fr_basepts = jts_fr_basepts.mean(dim=-2) score_jts_fr_basepts = real_dec_basejtsrel.mean(dim=-2) # bsz x seq x nnjts x 3 t_item = t[0].item() alpha_bar = self.var_sched.alpha_bars[t_item] combined_socre = score_jts - torch.sqrt(1. - alpha_bar) * socre_jts_fr_basepts combined_socre = score_jts # - torch.sqrt(1. - alpha_bar) * socre_jts_fr_basepts combined_socre = score_jts_fr_basepts if self.args.use_var_sched: bsz = real_dec_basejtsrel.size(0) t_item = t[0].item() alpha = self.var_sched.alphas[t_item] alpha_bar = self.var_sched.alpha_bars[t_item] sigma = self.var_sched.get_sigmas(t_item, 0.) # sigma # # sigma = sigma / 2. c0 = 1.0 / torch.sqrt(alpha) c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar) beta = self.var_sched.betas[[t[0].item()] * bsz] z = torch.randn_like(jts_fr_basepts) if t_item > 0 else torch.zeros_like(jts_fr_basepts) dec_jts_fr_basepts = c0 * (jts_fr_basepts - c1 * combined_socre) + sigma * z # theta else: # x_{t-1} dec_jts_fr_basepts = self._predict_xstart_from_eps(jts_fr_basepts, t=t, eps=combined_socre) real_dec_basejtsrel = dec_jts_fr_basepts.unsqueeze(-2) - x['normed_base_pts'].unsqueeze(1).unsqueeze(1) if self.args.real_basejtsrel_norm_stra == "std": real_dec_basejtsrel = (real_dec_basejtsrel - x['avg_rel_base_pts_to_rhand_joints']) / x['std_rel_base_pts_to_rhand_joints'] # diff_realbasejtsrel real_basejtsrel_seq_rt_dict = { "real_dec_basejtsrel": real_dec_basejtsrel, } basejtsrel_seq_rt_dict = { "basejtsrel_output": dec_jts_fr_basepts } if self.args.train_enc: # pert_rel_base_pts_to_rhand_joints raise ValueError(f"Trian enc --- Not implemented yet") pert_obj_base_pts_feats = x['pert_obj_base_pts_feats'] dec_obj_base_pts_feats = model.model.denoising_realbasejtsrel_objbasefeats(pert_obj_base_pts_feats, t) dec_obj_base_pts_feats = dec_obj_base_pts_feats.permute(1, 0, 2) pert_obj_base_pts_feats = pert_obj_base_pts_feats.permute(1, 0, 2) if self.args.pred_diff_noise: bsz = dec_obj_base_pts_feats.size(0) t_item = t[0].item() alpha = self.var_sched.alphas[t_item] alpha_bar = self.var_sched.alpha_bars[t_item] sigma = self.var_sched.get_sigmas(t_item, 0.) 
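                # Added note: in this `train_enc` path the denoising runs on the
                # latent object/base-point features (`pert_obj_base_pts_feats`,
                # laid out seq x bsz x dim, hence the permutes above), and the
                # denoised latents are decoded back to base-point-relative joints
                # via `decode_realbasejtsrel_from_objbasefeats` just below.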
# sigma # c0 = 1.0 / torch.sqrt(alpha) c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar) beta = self.var_sched.betas[[t[0].item()] * bsz] z = torch.randn_like(dec_obj_base_pts_feats) if t_item > 0 else torch.zeros_like(dec_obj_base_pts_feats) dec_obj_base_pts_feats = c0 * (pert_obj_base_pts_feats - c1 * dec_obj_base_pts_feats) + sigma * z # theta dec_obj_base_pts_feats = dec_obj_base_pts_feats.permute(1, 0, 2) real_dec_basejtsrel = model.model.decode_realbasejtsrel_from_objbasefeats(dec_obj_base_pts_feats, x) real_basejtsrel_seq_rt_dict.update( { 'real_dec_basejtsrel': real_dec_basejtsrel, 'dec_obj_base_pts_feats': dec_obj_base_pts_feats, } ) # else: # real_basejtsrel_seq_rt_dict = {} # basejtsrel_seq_rt_dict = {} if self.diff_basejtsrel and self.args.diff_realbasejtsrel_to_joints: pert_rel_base_pts_outputs = x['pert_joints_offset_sequence'] basejtsrel_output = out_dict['joints_offset_output'] score_jts = basejtsrel_output pert_joints_offset_output = x['pert_joints_offset_sequence'] # rel base pts outputs # dec_joints_offset_output = out_dict['joints_offset_output_from_rel'] # joints offset output # score_jts_fr_rel = dec_joints_offset_output # # pert joints offset sequence # # if self.args.pred_diff_noise: ## eps -> estimated-noises # predict noise? # # # if self.args.use_var_sched: bsz = dec_joints_offset_output.size(0) ## batch size fo the batch, pert-joints, predicted-joints ## t_item = t[0].item() alpha_bar = self.var_sched.alpha_bars[t_item] combined_socre = score_jts - torch.sqrt(1. - alpha_bar) * score_jts_fr_rel # combined_socre = score_jts # - torch.sqrt(1. - alpha_bar) * socre_jts_fr_basepts # combined_socre = score_jts_fr_rel alpha = self.var_sched.alphas[t_item] alpha_bar = self.var_sched.alpha_bars[t_item] sigma = self.var_sched.get_sigmas(t_item, 0.) # sigma # # sigma # c0 = 1.0 / torch.sqrt(alpha) ### c0? c1? --> c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar) beta = self.var_sched.betas[[t[0].item()] * bsz] z = torch.randn_like(dec_joints_offset_output) if t_item > 0 else torch.zeros_like(dec_joints_offset_output) # dec_joints dec_joints_offset_output = c0 * (pert_joints_offset_output - c1 * score_jts) + sigma * z # theta ### realjtsrel_to_joints and joints only ## realbasejtsrel_to_joints_rt_dict = { 'dec_joints_offset_output': dec_joints_offset_output, } basejtsrel_seq_rt_dict = { "basejtsrel_output": dec_joints_offset_output } # else: # realbasejtsrel_to_joints_rt_dict = {} # basejtsrel_seq_rt_dict = {} # rt_dict.update(jts_seq_rt_dict) rt_dict.update(basejtsrel_seq_rt_dict) # rt_dict.update(basejtse_seq_rt_dict) rt_dict.update(real_basejtsrel_seq_rt_dict) rt_dict.update(realbasejtsrel_to_joints_rt_dict) return rt_dict def p_mean_variance( ## get mean data ## self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None ): """ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of the initial x, x_0. :param model: the model, which takes a signal and a batch of timesteps as input. :param x: the [N x C x ...] tensor at time t. :param t: a 1-D Tensor of timesteps. :param clip_denoised: if True, clip the denoised signal into [-1, 1]. # # clip the denoised signal ## :param denoised_fn: if not None, a function which applies to the x_start prediction before it is used to sample. Applies before clip_denoised. # denoised fn :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :return: a dict with the following keys: - 'mean': the model mean output. - 'variance': the model variance output. 
- 'log_variance': the log of 'variance'. - 'pred_xstart': the prediction for x_0. """ if model_kwargs is None: model_kwargs = {} ## === version 1 -> predict x_start at the rel domain === ## ### == x -> formulated as model_inputs == ### ### == TODO: process x_start from the predicted value (relative positions, signed distances) with base_pts and baes_normals == ### # B = x['base_pts'].shape[0] B = t.shape[0] assert t.shape == (B,) # input_data = x ## dec_out and out ## ## output dict ## # out_dict = model(input_data, self._scale_timesteps(t).clone()) rt_dict = {} # # }[self.model_var_type] # ### === model variance and log_variance === ### ## posterior_log_variance_clipped, posterior_variance ## model_variance, model_log_variance = self.posterior_variance, self.posterior_log_variance_clipped # pmean variance if self.diff_jts: # x_t ## joints seq latents ## pert_joints_seq_latents = x['joints_seq_latents'] # x_t pred_clean_joint_seq_latents = out_dict["joints_seq_latents"] ## if self.args.pred_diff_noise: ## eps -> estimated noises ## t > for added joints latents ## pred_clean_joint_seq_latents = self._predict_xstart_from_eps(pert_joints_seq_latents.permute(1, 0, 2), t=t, eps=pred_clean_joint_seq_latents.permute(1, 0, 2)).permute(1, 0, 2) # seq x bs x d # # minn_pert_joints_seq_latents, _ = torch.min(pred_clean_joint_seq_latents.view(pred_clean_joint_seq_latents.size(0) * pred_clean_joint_seq_latents.size(1), -1).contiguous(), dim=0) # maxx_pert_joints_seq_latents, _ = torch.max(pred_clean_joint_seq_latents.view(pred_clean_joint_seq_latents.size(0) * pred_clean_joint_seq_latents.size(1), -1).contiguous(), dim=0) # print(f"pred minn_pert_joints_latents: {minn_pert_joints_seq_latents[:10]}, pred maxx_pert_joints_seq_latents: {maxx_pert_joints_seq_latents[:10]}") ## out_dict["joint_seq_output"] = model.model.dec_jts_only_fr_latents(pred_clean_joint_seq_latents)["joint_seq_output"] ## joints seq latents mean # # pred_clean_joint_seq_latents = pert_joints_seq_latents ## joints seq latents mean # joints_seq_latents_mean, _, _ = self.q_posterior_mean_variance( x_start=pred_clean_joint_seq_latents.permute(1, 0, 2), x_t=pert_joints_seq_latents.permute(1, 0, 2), t=t ) # joints_seq_latents_mean, basejtsrel_seq_latents_mean # joints_seq_latents_mean = joints_seq_latents_mean.permute(1, 0, 2) # joints_seq_latents_mean, basejtsrel_seq_latents_mean # joints_seq_latents_variance = _extract_into_tensor(model_variance, t, joints_seq_latents_mean.permute(1, 0, 2).shape).permute(1, 0, 2) joints_seq_latents_log_variance = _extract_into_tensor(model_log_variance, t, joints_seq_latents_mean.permute(1, 0, 2).shape).permute(1, 0, 2) # joint seq output # # joint seq output # joint_seq_output = out_dict["joint_seq_output"] jts_seq_rt_dict = { ### joints seq latents ### "joints_seq_latents_mean": joints_seq_latents_mean, "joints_seq_latents_variance": joints_seq_latents_variance, "joints_seq_latents_log_variance": joints_seq_latents_log_variance, ### decoded output values ### "joint_seq_output": joint_seq_output, } else: jts_seq_rt_dict = {} # /data1/sim/mdm/save/predoffset_stdscale_bsz_10_pred_diff_realbaesjtsrel_nonorm_std_for_norm_train_enc_with_diff_latents_prediffnoise_none_norm_rel_rel_to_jts_/model000007000.pt if self.args.diff_realbasejtsrel_to_joints: pert_joints_offset_output = x['pert_joints_offset_sequence'] # rel base pts outputs # dec_joints_offset_output = out_dict['joints_offset_output'] # joints offset output # if self.args.pred_diff_noise: ## eps -> estimated-noises # predict noise? 
# # if self.args.use_var_sched: bsz = dec_joints_offset_output.size(0) ## batch size fo the batch, pert-joints, predicted-joints ## t_item = t[0].item() alpha = self.var_sched.alphas[t_item] alpha_bar = self.var_sched.alpha_bars[t_item] sigma = self.var_sched.get_sigmas(t_item, 0.) # sigma # c0 = 1.0 / torch.sqrt(alpha) ### c0? c1? --> beta, z c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar) beta = self.var_sched.betas[[t[0].item()] * bsz] z = torch.randn_like(dec_joints_offset_output) if t_item > 0 else torch.zeros_like(dec_joints_offset_output) # dec_joints dec_joints_offset_output = c0 * (pert_joints_offset_output - c1 * dec_joints_offset_output) + sigma * z # theta ## use the predicted latents and pert_latents for the seq latents prediction ## dec_joints_offset_output_mean, _, _ = self.q_posterior_mean_variance( # q posterior x_start=dec_joints_offset_output, x_t=pert_joints_offset_output, t=t ) ## from model_variance to basejtsrel_seq_latents ### dec_joints_offset_output_variance = _extract_into_tensor(model_variance, t, dec_joints_offset_output.shape) dec_joints_offset_output_log_variance = _extract_into_tensor(model_log_variance, t, dec_joints_offset_output.shape) realbasejtsrel_to_joints_rt_dict = { 'dec_joints_offset_output': dec_joints_offset_output, } else: realbasejtsrel_to_joints_rt_dict = {} if self.diff_realbasejtsrel: # diff_realbasejtsrel, real_basejtsrel_seq_rt_dict real_dec_basejtsrel = out_dict['real_dec_basejtsrel'] # bsz x nf x nnj x nnb x 3 pert_rel_base_pts_to_rhand_joints = x['pert_rel_base_pts_to_rhand_joints'] if self.args.pred_diff_noise and not self.args.train_enc: if self.args.add_noise_onjts: jts_fr_basepts = pert_rel_base_pts_to_rhand_joints + x['normed_base_pts'].unsqueeze(1).unsqueeze(1) dec_jts_fr_basepts = real_dec_basejtsrel # + x['normed_base_pts'].unsqueeze(1).unsqueeze(1) # rel dec basejtsrel # if self.args.use_var_sched: bsz = dec_jts_fr_basepts.size(0) t_item = t[0].item() alpha = self.var_sched.alphas[t_item] alpha_bar = self.var_sched.alpha_bars[t_item] sigma = self.var_sched.get_sigmas(t_item, 0.) # sigma # c0 = 1.0 / torch.sqrt(alpha) c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar) beta = self.var_sched.betas[[t[0].item()] * bsz] z = torch.randn_like(jts_fr_basepts) if t_item > 0 else torch.zeros_like(jts_fr_basepts) dec_jts_fr_basepts = c0 * (jts_fr_basepts - c1 * dec_jts_fr_basepts) + sigma * z # theta # bsz x seq x nnjts x nnbase x 3 # # real_dec_basejtsrel = c0 * (pert_rel_base_pts_to_rhand_joints - c1 * real_dec_basejtsrel) + sigma * z # theta ## dec jts fr base pts ## # dec jts fr # dec_jts_fr_basepts = dec_jts_fr_basepts.mean(dim=-2).unsqueeze(-2).repeat(1, 1, 1, dec_jts_fr_basepts.size(-2), 1) # bsz x seq x nnjts x nnbase x 3 # else: # x_{t-1} dec_jts_fr_basepts = self._predict_xstart_from_eps(jts_fr_basepts, t=t, eps=dec_jts_fr_basepts) real_dec_basejtsrel = dec_jts_fr_basepts - x['normed_base_pts'].unsqueeze(1).unsqueeze(1) elif self.args.add_noise_onjts_single: jts_fr_basepts = pert_rel_base_pts_to_rhand_joints + x['normed_base_pts'].unsqueeze(1).unsqueeze(1) jts_fr_basepts = jts_fr_basepts.mean(dim=-2) # dec_jts_fr_basepts = real_dec_basejtsrel.mean(dim=-2) # + x['normed_base_pts'].unsqueeze(1).unsqueeze(1) # rel dec basejtsrel # dec_jts_fr_basepts = real_dec_basejtsrel[..., 0, :] if self.args.use_var_sched: bsz = dec_jts_fr_basepts.size(0) t_item = t[0].item() alpha = self.var_sched.alphas[t_item] alpha_bar = self.var_sched.alpha_bars[t_item] sigma = self.var_sched.get_sigmas(t_item, 0.) 
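                    # Added note: unlike `add_noise_onjts`, which diffuses joints
                    # reconstructed against all base points (bsz x seq x nnjts x
                    # nnbase x 3), this `add_noise_onjts_single` branch first
                    # collapses to a single joint trajectory (mean over base points
                    # for x_t, a single selected base point for the score) and
                    # diffuses that instead.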
# sigma # c0 = 1.0 / torch.sqrt(alpha) c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar) beta = self.var_sched.betas[[t[0].item()] * bsz] z = torch.randn_like(jts_fr_basepts) if t_item > 0 else torch.zeros_like(jts_fr_basepts) dec_jts_fr_basepts = c0 * (jts_fr_basepts - c1 * dec_jts_fr_basepts) + sigma * z # theta # bsz x seq x nnjts x nnbase x 3 # else: # x_{t-1} dec_jts_fr_basepts = self._predict_xstart_from_eps(jts_fr_basepts, t=t, eps=dec_jts_fr_basepts) real_dec_basejtsrel = dec_jts_fr_basepts.unsqueeze(-2) - x['normed_base_pts'].unsqueeze(1).unsqueeze(1) else: if self.args.use_var_sched: bsz = real_dec_basejtsrel.size(0) t_item = t[0].item() alpha = self.var_sched.alphas[t_item] alpha_bar = self.var_sched.alpha_bars[t_item] sigma = self.var_sched.get_sigmas(t_item, 0.) # sigma # # sigma = sigma / 2. c0 = 1.0 / torch.sqrt(alpha) c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar) beta = self.var_sched.betas[[t[0].item()] * bsz] z = torch.randn_like(pert_rel_base_pts_to_rhand_joints) if t_item > 0 else torch.zeros_like(pert_rel_base_pts_to_rhand_joints) # z = torch.zeros_like(pert_rel_base_pts_to_rhand_joints) real_dec_basejtsrel = c0 * (pert_rel_base_pts_to_rhand_joints - c1 * real_dec_basejtsrel) + sigma * z # theta # dec_jts_fr_basepts = real_dec_basejtsrel + x['normed_base_pts'].unsqueeze(1).unsqueeze(1) # get dec_jts fr basepts # dec_jts_fr_basepts = dec_jts_fr_basepts.mean(dim=-2).unsqueeze(-2).repeat(1, 1, 1, dec_jts_fr_basepts.size(-2), 1) # repeated basepts # real_dec_basejtsrel = dec_jts_fr_basepts - x['normed_base_pts'].unsqueeze(1).unsqueeze(1) # real dec # else: # x_{t-1} real_dec_basejtsrel = self._predict_xstart_from_eps(pert_rel_base_pts_to_rhand_joints, t=t, eps=real_dec_basejtsrel) ## use the predicted latents and pert_latents for the seq latents prediction ## real_dec_basejtsrel_mean, _, _ = self.q_posterior_mean_variance( # q posterior x_start=real_dec_basejtsrel, x_t=pert_rel_base_pts_to_rhand_joints, t=t ) ## from model_variance to basejtsrel_seq_latents ### real_dec_basejtsrel_variance = _extract_into_tensor(model_variance, t, real_dec_basejtsrel.shape) real_dec_basejtsrel_log_variance = _extract_into_tensor(model_log_variance, t, real_dec_basejtsrel.shape) # diff_realbasejtsrel real_basejtsrel_seq_rt_dict = { "real_dec_basejtsrel": real_dec_basejtsrel, "real_dec_basejtsrel_mean": real_dec_basejtsrel_mean, "real_dec_basejtsrel_variance": real_dec_basejtsrel_variance, "real_dec_basejtsrel_log_variance": real_dec_basejtsrel_log_variance, } if self.args.train_enc: # pert_rel_base_pts_to_rhand_joints pert_obj_base_pts_feats = x['pert_obj_base_pts_feats'] dec_obj_base_pts_feats = model.model.denoising_realbasejtsrel_objbasefeats(pert_obj_base_pts_feats, t) dec_obj_base_pts_feats = dec_obj_base_pts_feats.permute(1, 0, 2) pert_obj_base_pts_feats = pert_obj_base_pts_feats.permute(1, 0, 2) if self.args.pred_diff_noise: bsz = dec_obj_base_pts_feats.size(0) t_item = t[0].item() alpha = self.var_sched.alphas[t_item] alpha_bar = self.var_sched.alpha_bars[t_item] sigma = self.var_sched.get_sigmas(t_item, 0.) 
            if self.args.train_enc:  # pert_rel_base_pts_to_rhand_joints
                pert_obj_base_pts_feats = x['pert_obj_base_pts_feats']
                dec_obj_base_pts_feats = model.model.denoising_realbasejtsrel_objbasefeats(pert_obj_base_pts_feats, t)
                dec_obj_base_pts_feats = dec_obj_base_pts_feats.permute(1, 0, 2)
                pert_obj_base_pts_feats = pert_obj_base_pts_feats.permute(1, 0, 2)
                if self.args.pred_diff_noise:
                    bsz = dec_obj_base_pts_feats.size(0)
                    t_item = t[0].item()
                    alpha = self.var_sched.alphas[t_item]
                    alpha_bar = self.var_sched.alpha_bars[t_item]
                    sigma = self.var_sched.get_sigmas(t_item, 0.)  # sigma #
                    c0 = 1.0 / torch.sqrt(alpha)
                    c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
                    beta = self.var_sched.betas[[t[0].item()] * bsz]
                    z = torch.randn_like(dec_obj_base_pts_feats) if t_item > 0 else torch.zeros_like(dec_obj_base_pts_feats)
                    dec_obj_base_pts_feats = c0 * (pert_obj_base_pts_feats - c1 * dec_obj_base_pts_feats) + sigma * z  # theta
                dec_obj_base_pts_feats = dec_obj_base_pts_feats.permute(1, 0, 2)
                real_dec_basejtsrel = model.model.decode_realbasejtsrel_from_objbasefeats(dec_obj_base_pts_feats, x)
                real_basejtsrel_seq_rt_dict.update(
                    {
                        'real_dec_basejtsrel': real_dec_basejtsrel,
                        'dec_obj_base_pts_feats': dec_obj_base_pts_feats,
                    }
                )
        else:
            real_basejtsrel_seq_rt_dict = {}

        if self.diff_basejtsrel:
            # perturbed joints-quants latents #
            pert_joints_quants_latents = x['pert_joints_quants_latents']
            # denoised_latents: seq_len x bsz x latent_dim #
            denoised_latents = model.model.denoising_joint_quants_feats(pert_joints_quants_latents, self._scale_timesteps(t).clone())
            # pert_rel_base_pts_outputs = x['pert_joints_offset_sequence']  # rel base pts outputs
            basejtsrel_output = out_dict['joints_offset_output']
            if self.args.use_var_sched:
                # bsz = basejtsrel_output.size(0)
                bsz = B
                t_item = t[0].item()
                alpha = self.var_sched.alphas[t_item]
                alpha_bar = self.var_sched.alpha_bars[t_item]
                sigma = self.var_sched.get_sigmas(t_item, 0.)  # sigma #
                c0 = 1.0 / torch.sqrt(alpha)
                c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
                # x_t = traj[t]
                # beta = self.var_sched.betas[[t[0].item()] * bsz]
                # if mask is not None:
                #     x_t = x_t * mask
                # e_theta = self.net(x_t, beta=beta, context=context)  # denoising
                z = torch.randn_like(pert_joints_quants_latents) if t_item > 0 else torch.zeros_like(pert_joints_quants_latents)
                denoised_latents = c0 * (pert_joints_quants_latents - c1 * denoised_latents) + sigma * z  # theta
            else:
                denoised_latents = self._predict_xstart_from_eps(pert_joints_quants_latents, t=t, eps=denoised_latents)
            if self.args.train_enc:  # joints_quants_latents -> pred joint quants
                pred_joint_quants = model.model.pred_joint_quants_from_latent_feats(x['joints_quants_latents'])
            else:
                pred_joint_quants = model.model.pred_joint_quants_from_latent_feats(denoised_latents)
            # basejtsrel_output = out_dict["basejtsrel_output"]
            # print(f"basejtsrel_output: {basejtsrel_output.size()}")
            basejtsrel_seq_rt_dict = {  ### basejtsrel seq latents ###
                # "avg_jts_outputs": avg_jts_outputs,
                # "basejtsrel_output_variance": basejtsrel_output_variance,
                # "basejtsrel_output_log_variance": basejtsrel_output_log_variance,
                # "avg_jts_outputs_variance": avg_jts_outputs_variance,
                # "avg_jts_outputs_log_variance": avg_jts_outputs_log_variance,
                # "basejtsrel_output": basejtsrel_output,
                'pred_joint_quants': pred_joint_quants,
                'denoised_latents': denoised_latents,
            }
        else:
            basejtsrel_seq_rt_dict = {}

        if self.diff_basejtse:
            dec_e_along_normals = out_dict['dec_e_along_normals']
            dec_e_vt_normals = out_dict['dec_e_vt_normals']
            pert_e_along_normals = x['pert_e_disp_rel_to_base_along_normals']
            pert_e_vt_normals = x['pert_e_disp_rel_to_base_vt_normals']
            # pert_rel_base_pts_outputs = x['pert_joints_offset_sequence']  # rel base pts outputs
            # basejtsrel_output = out_dict['joints_offset_output']
            if self.args.pred_diff_noise:  ## eps -> estimated noise ##
                if self.args.use_var_sched:
                    bsz = dec_e_along_normals.size(0)
                    t_item = t[0].item()
                    alpha = self.var_sched.alphas[t_item]
                    alpha_bar = self.var_sched.alpha_bars[t_item]
                    sigma = self.var_sched.get_sigmas(t_item, 0.)  # sigma #
                    c0 = 1.0 / torch.sqrt(alpha)
                    c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
                    beta = self.var_sched.betas[[t[0].item()] * bsz]
                    z = torch.randn_like(dec_e_along_normals) if t_item > 0 else torch.zeros_like(dec_e_along_normals)
                    dec_e_along_normals = c0 * (pert_e_along_normals - c1 * dec_e_along_normals) + sigma * z  # theta
                    z_vt_normals = torch.randn_like(dec_e_vt_normals) if t_item > 0 else torch.zeros_like(dec_e_vt_normals)
                    dec_e_vt_normals = c0 * (pert_e_vt_normals - c1 * dec_e_vt_normals) + sigma * z_vt_normals  # theta
                else:
                    dec_e_along_normals = self._predict_xstart_from_eps(pert_e_along_normals, t=t, eps=dec_e_along_normals)
                    dec_e_vt_normals = self._predict_xstart_from_eps(pert_e_vt_normals, t=t, eps=dec_e_vt_normals)
            # base_jts_e_feats = x['base_jts_e_feats']  ### x_t values here ###
            # pred_basejtse_seq_latents = out_dict['base_jts_e_feats']
            # ### q-sampled latent mean here ###
            # basejtse_seq_latents_mean, _, _ = self.q_posterior_mean_variance(
            #     x_start=pred_basejtse_seq_latents.permute(1, 0, 2), x_t=base_jts_e_feats.permute(1, 0, 2), t=t
            # )
            # basejtse_seq_latents_mean = basejtse_seq_latents_mean.permute(1, 0, 2)
            # basejtse_seq_latents_variance = _extract_into_tensor(model_variance, t, basejtse_seq_latents_mean.permute(1, 0, 2).shape).permute(1, 0, 2)
            # basejtse_seq_latents_log_variance = _extract_into_tensor(model_log_variance, t, basejtse_seq_latents_mean.permute(1, 0, 2).shape).permute(1, 0, 2)
            # base_jts_e_feats = out_dict["base_jts_e_feats"]
            # dec_e_along_normals = out_dict["dec_e_along_normals"]
            # dec_e_vt_normals = out_dict["dec_e_vt_normals"]
            basejtse_seq_rt_dict = {  ### basejtse seq latents ###
                # "basejtse_seq_latents_mean": basejtse_seq_latents_mean,
                # "basejtse_seq_latents_variance": basejtse_seq_latents_variance,
                # "basejtse_seq_latents_log_variance": basejtse_seq_latents_log_variance,
                "dec_e_along_normals": dec_e_along_normals,
                "dec_e_vt_normals": dec_e_vt_normals,
            }
        else:
            basejtse_seq_rt_dict = {}

        rt_dict.update(jts_seq_rt_dict)
        rt_dict.update(basejtsrel_seq_rt_dict)
        rt_dict.update(basejtse_seq_rt_dict)
        rt_dict.update(real_basejtsrel_seq_rt_dict)
        rt_dict.update(realbasejtsrel_to_joints_rt_dict)
        return rt_dict

    def _predict_xstart_from_eps(self, x_t, t, eps):
        assert x_t.shape == eps.shape
        return (  # extract into tensor #
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
        )

    def _predict_xstart_from_xprev(self, x_t, t, xprev):
        assert x_t.shape == xprev.shape
        return (  # (xprev - coef2 * x_t) / coef1
            _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
            - _extract_into_tensor(
                self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
            )
            * x_t
        )

    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
        return (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - pred_xstart
        ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    def _scale_timesteps(self, t):
        if self.rescale_timesteps:
            return t.float() * (1000.0 / self.num_timesteps)
        return t
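    # Illustrative sketch (added): the variance-scheduled update that the
    # branches above repeat inline. It assumes a VarianceSchedule-like object
    # exposing `alphas`, `alpha_bars` and `get_sigmas`, mirroring how
    # self.var_sched is used; this is a minimal restatement for readability,
    # not part of the pipeline.
    @staticmethod
    def _var_sched_step_sketch(var_sched, x_t, eps_pred, t_item):
        alpha = var_sched.alphas[t_item]
        alpha_bar = var_sched.alpha_bars[t_item]
        sigma = var_sched.get_sigmas(t_item, 0.)
        c0 = 1.0 / torch.sqrt(alpha)
        c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
        # no noise is injected at the final step (t == 0)
        z = torch.randn_like(x_t) if t_item > 0 else torch.zeros_like(x_t)
        return c0 * (x_t - c1 * eps_pred) + sigma * z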
    def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute the mean for the previous step, given a function cond_fn that
        computes the gradient of a conditional log probability with respect to
        x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
        condition on y.

        This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
        """
        gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
        new_mean = (
            p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
        )
        return new_mean

    def condition_mean_with_grad(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute the mean for the previous step, given a function cond_fn that
        computes the gradient of a conditional log probability with respect to
        x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
        condition on y.

        This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
        """
        gradient = cond_fn(x, t, p_mean_var, **model_kwargs)
        new_mean = (
            p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
        )
        return new_mean

    def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute what the p_mean_variance output would have been, should the
        model's score function be conditioned by cond_fn.

        See condition_mean() for details on cond_fn.

        Unlike condition_mean(), this instead uses the conditioning strategy
        from Song et al (2020).
        """
        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
        eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
        eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
            x, self._scale_timesteps(t), **model_kwargs
        )
        out = p_mean_var.copy()
        out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
        out["mean"], _, _ = self.q_posterior_mean_variance(
            x_start=out["pred_xstart"], x_t=x, t=t
        )
        return out

    def condition_score_with_grad(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute what the p_mean_variance output would have been, should the
        model's score function be conditioned by cond_fn.

        See condition_mean() for details on cond_fn.

        Unlike condition_mean(), this instead uses the conditioning strategy
        from Song et al (2020).
        """
        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
        eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
        eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
            x, t, p_mean_var, **model_kwargs
        )
        out = p_mean_var.copy()
        out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
        out["mean"], _, _ = self.q_posterior_mean_variance(
            x_start=out["pred_xstart"], x_t=x, t=t
        )
        return out

    def judge_activated(self, target_setting):
        return 1 if target_setting else 0
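    # Illustrative sketch (added): a classifier-guidance style `cond_fn`
    # compatible with condition_mean/condition_score above. `classifier` and
    # the label tensor `y` are hypothetical stand-ins; the only contract is
    # that the function returns grad_x log p(y | x_t).
    @staticmethod
    def _example_classifier_cond_fn(classifier, x, t, y, guidance_scale=1.0):
        with th.enable_grad():
            x_in = x.detach().requires_grad_(True)
            log_probs = th.log_softmax(classifier(x_in, t), dim=-1)
            selected = log_probs[th.arange(len(y)), y]
            return th.autograd.grad(selected.sum(), x_in)[0] * guidance_scale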
""" multi_activated = ( self.judge_activated(self.diff_jts) + self.judge_activated(self.args.diff_realbasejtsrel_to_joints) + self.judge_activated(self.diff_realbasejtsrel) + self.judge_activated(self.diff_basejtsrel) + self.judge_activated(self.diff_basejtse) ) > 1.5 if multi_activated: # print(f"Multiple settings activated! Using combined sampling...") p_mena_variance_fn = self.p_mean_variance_cond else: # print(f"Single setting activated! Using single sampling...") p_mena_variance_fn = self.p_mean_variance out = p_mena_variance_fn( model, x, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) rt_dict = {} if self.diff_jts: # bsz x ws x nnj x nnb x 3 # joints_seq_latents_noise = th.randn_like(x['joints_seq_latents']) # print('const_noise', const_noise) # seq x bsz x latent_dim # if const_noise: print(f"joints latents hape, ", x['joints_seq_latents'].shape) joints_seq_latents_noise = joints_seq_latents_noise[[0]].repeat(x['joints_seq_latents'].shape[0], 1, 1) # joints_seq_latents_nonzero_mask = ( (t != 0).float().view(-1, *([1] * (len(x['joints_seq_latents'].shape) - 1))) ) # no noise when t == 0 # bsz x nseq x dim # #### ==== joints_seq_latents ===== #### # t -> seq for const nosie .... # cnanot dpeict the laten tspace very well... # joints_seq_latents_sample = out["joints_seq_latents_mean"].permute(1, 0, 2) + joints_seq_latents_nonzero_mask * th.exp(0.5 * out["joints_seq_latents_log_variance"].permute(1, 0, 2)) * joints_seq_latents_noise.permute(1, 0, 2) # nseq x bsz x dim # joints_seq_latents_sample = joints_seq_latents_sample.permute(1, 0, 2) # #### ==== joints_seq_latents ===== #### joint_seq_output = out["joint_seq_output"] jts_seq_rt_dict = { "joints_seq_latents_sample": joints_seq_latents_sample, "joint_seq_output": joint_seq_output, } else: jts_seq_rt_dict = {} if self.args.diff_realbasejtsrel_to_joints: ## args.pred to joints # dec_joints_offset_output = realbasejtsrel_to_joints_rt_dict = { 'dec_joints_offset_output': out['dec_joints_offset_output'] } else: realbasejtsrel_to_joints_rt_dict = {} if self.diff_realbasejtsrel: if self.args.train_enc or ( self.args.pred_diff_noise and self.args.use_var_sched): real_dec_basejtsrel = out['real_dec_basejtsrel'] else: real_dec_basejtsrel_noise = th.randn_like(out['real_dec_basejtsrel']) if const_noise: real_dec_basejtsrel_noise = real_dec_basejtsrel_noise[[0]].repeat(out['real_dec_basejtsrel'].shape[0], 1, 1, 1, 1) real_dec_basejtsrel_nonzero_mask = ( (t != 0).float().view(-1, *([1] * (len(out['real_dec_basejtsrel'].shape) - 1))) ) real_dec_basejtsrel = out["real_dec_basejtsrel_mean"] + real_dec_basejtsrel_nonzero_mask * th.exp(0.5 * out["real_dec_basejtsrel_log_variance"]) * real_dec_basejtsrel_noise real_basejtsrel_rt_dict = { 'real_dec_basejtsrel': real_dec_basejtsrel, } if self.args.train_enc: real_basejtsrel_rt_dict['dec_obj_base_pts_feats'] = out['dec_obj_base_pts_feats'] else: real_basejtsrel_rt_dict = {} if self.diff_basejtsrel: basejtsrel_rt_dict = out else: basejtsrel_rt_dict = {} if self.diff_basejtse: # ##### ===== Sample for basejtse_seq_latents_sample ===== ##### # ### rel_base_pts_outputs mask ### # basejtse_seq_latents_noise = th.randn_like(x['base_jts_e_feats']) # # print('const_noise', const_noise) # if const_noise: # basejtse_seq_latents_noise = basejtse_seq_latents_noise[[0]].repeat(x['base_jts_e_feats'].shape[0], 1, 1, 1, 1) # basejtse_seq_latents_nonzero_mask = ( # (t != 0).float().view(-1, *([1] * (len(x['base_jts_e_feats'].shape) - 1))) # ) # no noise when t == 0 # #### ==== 
basejtsrel_seq_latents ===== #### # basejtse_seq_latents_sample = out["basejtse_seq_latents_mean"].permute(1, 0, 2) + basejtse_seq_latents_nonzero_mask * th.exp(0.5 * out["basejtse_seq_latents_log_variance"].permute(1, 0, 2)) * basejtse_seq_latents_noise.permute(1, 0, 2) # basejtse_seq_latents_sample = basejtse_seq_latents_sample.permute(1, 0, 2) # #### ==== basejtsrel_seq_latents ===== #### # ##### ===== Sample for basejtse_seq_latents_sample ===== ##### dec_e_along_normals = out["dec_e_along_normals"] ## dec_e_vt_normals = out["dec_e_vt_normals"] basejtse_rt_dict = { # "basejtse_seq_latents_sample": basejtse_seq_latents_sample, "dec_e_along_normals": dec_e_along_normals, "dec_e_vt_normals": dec_e_vt_normals, } else: basejtse_rt_dict = {} rt_dict.update(jts_seq_rt_dict) rt_dict.update(basejtsrel_rt_dict) rt_dict.update(basejtse_rt_dict) rt_dict.update(real_basejtsrel_rt_dict) rt_dict.update(realbasejtsrel_to_joints_rt_dict) return rt_dict def p_sample_with_grad( self, model, x, t, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, ): """ Sample x_{t-1} from the model at the given timestep. :param model: the model to sample from. :param x: the current tensor at x_{t-1}. :param t: the value of t, starting at 0 for the first diffusion step. :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. :param denoised_fn: if not None, a function which applies to the x_start prediction before it is used to sample. :param cond_fn: if not None, this is a gradient function that acts similarly to the model. :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :return: a dict containing the following keys: - 'sample': a random sample from the model. - 'pred_xstart': a prediction of x_0. """ with th.enable_grad(): x = x.detach().requires_grad_() out = self.p_mean_variance( model, x, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) noise = th.randn_like(x) nonzero_mask = ( (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) ) # no noise when t == 0 if cond_fn is not None: out["mean"] = self.condition_mean_with_grad( cond_fn, out, x, t, model_kwargs=model_kwargs ) sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise return {"sample": sample, "pred_xstart": out["pred_xstart"].detach()} def p_sample_loop( self, model, shape, noise=None, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, progress=False, skip_timesteps=0, init_image=None, randomize_class=False, cond_fn_with_grad=False, dump_steps=None, const_noise=False, st_timestep=None, ): ## """ Generate samples from the model. :param model: the model module. :param shape: the shape of the samples, (N, C, H, W). :param noise: if specified, the noise from the encoder to sample. Should be of the same shape as `shape`. :param clip_denoised: if True, clip x_start predictions to [-1, 1]. :param denoised_fn: if not None, a function which applies to the x_start prediction before it is used to sample. :param cond_fn: if not None, this is a gradient function that acts similarly to the model. :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :param device: if specified, the device to create the samples on. If not specified, use a model parameter's device. :param progress: if True, show a tqdm progress bar. 
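    # Illustrative sketch (added): the reparameterized draw both samplers
    # above use, x_{t-1} = mu + 1[t > 0] * exp(0.5 * log_var) * noise, with
    # the mask built from the timestep batch t exactly as in p_sample.
    @staticmethod
    def _reparam_sample_sketch(mean, log_variance, t):
        noise = th.randn_like(mean)
        nonzero_mask = (t != 0).float().view(-1, *([1] * (len(mean.shape) - 1)))
        return mean + nonzero_mask * th.exp(0.5 * log_variance) * noise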
    def p_sample_loop(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        skip_timesteps=0,
        init_image=None,
        randomize_class=False,
        cond_fn_with_grad=False,
        dump_steps=None,
        const_noise=False,
        st_timestep=None,
    ):
        """
        Generate samples from the model.

        :param model: the model module.
        :param shape: the shape of the samples, (N, C, H, W).
        :param noise: if specified, the noise from the encoder to sample.
                      Should be of the same shape as `shape`.
        :param clip_denoised: if True, clip x_start predictions to [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample.
        :param cond_fn: if not None, this is a gradient function that acts
                        similarly to the model.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param device: if specified, the device to create the samples on.
                       If not specified, use a model parameter's device.
        :param progress: if True, show a tqdm progress bar.
        :param const_noise: if True, will noise all samples with the same
                            noise throughout sampling.
        :return: a non-differentiable batch of samples.
        """
        final = None
        dump = []  # filled when dump_steps is not None #
        for i, sample in enumerate(self.p_sample_loop_progressive(
            model,  # p_sample #
            shape,
            noise=noise,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            cond_fn=cond_fn,
            model_kwargs=model_kwargs,
            device=device,
            progress=progress,
            skip_timesteps=skip_timesteps,
            init_image=init_image,
            randomize_class=randomize_class,
            cond_fn_with_grad=cond_fn_with_grad,
            const_noise=const_noise,  # the same noise #
            st_timestep=st_timestep,
        )):
            if dump_steps is not None and i in dump_steps:
                dump.append(deepcopy(sample))
            final = sample
        if dump_steps is not None:
            return dump
        return final
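    # Usage sketch (added, hypothetical setup): intermediate states can also
    # be consumed straight from the progressive generator, e.g.
    #
    #   for i, sample in enumerate(diffusion.p_sample_loop_progressive(
    #           model, shape, init_image=init_image, progress=True)):
    #       ...  # inspect the partially denoised dict at step i
    #
    # p_sample_loop above is exactly this loop plus dump_steps bookkeeping.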
    def p_sample_loop_progressive(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        skip_timesteps=0,
        init_image=None,
        randomize_class=False,
        cond_fn_with_grad=False,
        const_noise=False,
        st_timestep=None,
    ):
        """
        Generate samples from the model and yield intermediate samples from
        each timestep of diffusion.

        Arguments are the same as p_sample_loop().
        Returns a generator over dicts, where each dict is the return value of
        p_sample().
        """
        ####### ==== a conditional sampling from init_image here ==== #######
        ### === init_image should not be None === ###
        base_pts = init_image['base_pts']
        base_normals = init_image['base_normals']  ## base normals ##
        # rel_base_pts_to_rhand_joints = init_image['rel_base_pts_to_rhand_joints']
        # dist_base_pts_to_rhand_joints = init_image['dist_base_pts_to_rhand_joints']
        rhand_joints = init_image['rhand_joints']
        # rhand_joints = init_image['gt_rhand_joints']
        obj_rot = init_image['obj_rot']  # bsz x nf x 3 x 3
        obj_transl = init_image['obj_transl']  # bsz x nf x 3 --> transl of obj base

        base_pts = torch.matmul(base_pts.unsqueeze(1), obj_rot) + obj_transl.unsqueeze(-2)  # bsz x nf x nn_base_pts x 3
        base_normals = torch.matmul(base_normals.unsqueeze(1), obj_rot)  # bsz x nf x nn_base_pts x 3
        if len(self.args.predicted_info_fn) > 0:
            print(f"Joints are not transformed here!")
            rhand_joints = rhand_joints  # already in the object frame #
        else:
            rhand_joints = torch.matmul(rhand_joints, obj_rot) + obj_transl.unsqueeze(-2)  # bsz x nf x nnj x 3

        gt_rhand_joints = init_image['gt_rhand_joints']
        gt_rhand_joints = torch.matmul(gt_rhand_joints, obj_rot) + obj_transl.unsqueeze(-2)  # bsz x nf x nnj x 3

        # calculate_disp_quants_batched: each returned quantity has shape bsz x (nf - 1) x nn_joint #
        dist_joints_to_base_pts_disp, dist_disp_along_dir, dist_disp_vt_dir = calculate_disp_quants_batched(rhand_joints, base_pts)
        gt_dist_joints_to_base_pts_disp, gt_dist_disp_along_dir, gt_dist_disp_vt_dir = calculate_disp_quants_batched(gt_rhand_joints, base_pts)
        # joints_quants: bsz x (nf - 1) x nnjoints x 3 #
        joints_quants = torch.cat(
            [dist_joints_to_base_pts_disp.unsqueeze(-1), dist_disp_along_dir.unsqueeze(-1), dist_disp_vt_dir.unsqueeze(-1)], dim=-1
        )

        diff_pert_gt_joints_to_base_pts_disp = torch.mean(
            (dist_joints_to_base_pts_disp - gt_dist_joints_to_base_pts_disp) ** 2
        )
        diff_pert_gt_dist_disp_along_dir = torch.mean(
            (dist_disp_along_dir - gt_dist_disp_along_dir) ** 2
        )
        diff_pert_gt_dist_disp_vt_dir = torch.mean(
            (dist_disp_vt_dir - gt_dist_disp_vt_dir) ** 2
        )
        print(f"diff_pert_gt_joints_to_base_pts_disp: {diff_pert_gt_joints_to_base_pts_disp.item()}, diff_pert_gt_dist_disp_along_dir: {diff_pert_gt_dist_disp_along_dir.item()}, diff_pert_gt_dist_disp_vt_dir: {diff_pert_gt_dist_disp_vt_dir.item()}")
        # joints_quants = torch.cat(
        #     [dist_joints_to_base_pts_disp.unsqueeze(-1), dist_joints_to_base_pts_disp.unsqueeze(-1), dist_joints_to_base_pts_disp.unsqueeze(-1)], dim=-1
        # )
        # joints_quants = torch.cat(
        #     [dist_disp_vt_dir.unsqueeze(-1), dist_disp_vt_dir.unsqueeze(-1), dist_disp_vt_dir.unsqueeze(-1)], dim=-1
        # )

        input_data = {
            'joints_quants': joints_quants
        }  ### without e normalization ###

        indices = list(range(self.num_timesteps - skip_timesteps))[::-1]
        if st_timestep is not None:
            indices = indices[-st_timestep:]
            print(f"st_timestep: {st_timestep}, indices: {indices}")

        my_t = th.tensor([indices[0]] * shape[0], device=device).cuda()
        ### the strategy of removing noise from corresponding quantities ###
        out_dict = model(input_data, self._scale_timesteps(my_t).clone())
        joints_quants_latents = out_dict['joints_quants_latents']
        pred_joint_quants = out_dict['joints_quants_output']
        noise_joints_quants_latents = th.randn_like(joints_quants_latents)
        pert_joints_quants_latents = self.q_sample(
            joints_quants_latents.permute(1, 0, 2), my_t, noise_joints_quants_latents.permute(1, 0, 2)
        ).permute(1, 0, 2)  # seq_len x bsz x latent_dim #

        # input_data = {
        #     'base_pts': base_pts,
        #     'base_normals': base_normals,
        #     # 'rel_base_pts_to_rhand_joints': rel_base_pts_to_pert_rhand_joints,
        #     # 'dist_base_pts_to_rhand_joints': dist_base_pts_to_pert_rhand_joints,
        #     # 'pert_rhand_joints': pert_normed_rhand_joints,
        #     # 'pert_rhand_joints': pert_scaled_rhand_joints,
        #     'rhand_joints': rhand_joints,
        #     # 'rel_base_pts_to_rhand_joints': rel_base_pts_to_rhand_joints,
        #     # 'avg_joints_sequence': avg_joints_sequence,
        #     # 'pert_avg_joints_sequence': pert_avg_joints_sequence,  ## pert avg joints sequence
        #     'pert_joints_offset_sequence': pert_joints_offset_sequence,
        #     'pert_rel_base_pts_to_rhand_joints': pert_rel_base_pts_to_rhand_joints,  ## pert relative base pts to rhand joints ##
        #     'normed_base_pts': normed_base_pts,
        #     'pert_rel_base_pts_to_joints_for_jts_pred': pert_rel_base_pts_to_joints_for_jts_pred,  # bsz x nf x nnj x nnb x 3 --> from base points to joints
        # }
        input_data['pert_joints_quants_latents'] = pert_joints_quants_latents  # add perturbed information
        input_data['joints_quants_latents'] = joints_quants_latents  # clean latent vectors

        # if self.args.train_enc:
        #     diff_joints_quants_pred_joints_quants_dist = torch.sum(
        #         (joints_quants[..., 0:1] - pred_joint_quants[..., 0:1]) ** 2, dim=-1
        #     ).mean()
        #     print(f"[train_enc] Current step diff_joints_quants_pred_joints_quants_dist: {diff_joints_quants_pred_joints_quants_dist.item()}")
        #     diff_joints_quants_pred_joints_quants_along_normals = torch.sum(
        #         (joints_quants[..., 1:2] - pred_joint_quants[..., 1:2]) ** 2, dim=-1
        #     ).mean()
        #     print(f"[train_enc] Current step diff_joints_quants_pred_joints_quants_along_normals: {diff_joints_quants_pred_joints_quants_along_normals.item()}")
        #     diff_joints_quants_pred_joints_quants_vt_normals = torch.sum(
        #         (joints_quants[..., 2:] - pred_joint_quants[..., 2:]) ** 2, dim=-1
        #     ).mean()
        #     print(f"[train_enc] Current step diff_joints_quants_pred_joints_quants_vt_normals: {diff_joints_quants_pred_joints_quants_vt_normals.item()}")
        #     input_data.update(
        #         {'pred_joint_quants': pred_joint_quants}
        #     )
        #     for t in range(len(indices)):
        #         yield input_data
        #     return

        # primal space denoising ->
        # input_data.update(
        #     {
        #         'e_disp_rel_to_base_along_normals': e_disp_rel_to_base_along_normals,
        #         'e_disp_rel_to_baes_vt_normals': e_disp_rel_to_baes_vt_normals,
        #         'pert_e_disp_rel_to_base_along_normals': pert_e_disp_rel_to_base_along_normals,
        #         'pert_e_disp_rel_to_base_vt_normals': pert_e_disp_rel_to_base_vt_normals,
        #     }
        # )
        # input_data.update(init_image_avg_std_stats)
        # input_data['rhand_joints'] = rhand_joints  # normed
        # if self.args.real_basejtsrel_norm_stra == "std":
        #     input_data.update(
        #         {
        #             'avg_rel_base_pts_to_rhand_joints': avg_rel_base_pts_to_rhand_joints,
        #             'std_rel_base_pts_to_rhand_joints': std_rel_base_pts_to_rhand_joints,
        #         }
        #     )

        if self.args.train_enc:
            # out_dict = model(input_data, self._scale_timesteps(my_t).clone())
            # obj_base_pts_feats = out_dict['obj_base_pts_feats']  # obj base pts feats #
            # noise_obj_base_pts_feats = torch.randn_like(obj_base_pts_feats)
            # pert_obj_base_pts_feats = self.q_sample(
            #     obj_base_pts_feats.permute(1, 0, 2), my_t, noise_obj_base_pts_feats.permute(1, 0, 2)
            # ).permute(1, 0, 2)
            # if self.args.rnd_noise:
            #     pert_obj_base_pts_feats = noise_obj_base_pts_feats
            # input_data['pert_obj_base_pts_feats'] = pert_obj_base_pts_feats
            input_data['pert_joints_quants_latents'] = joints_quants_latents  # use the clean latents when only the encoder is trained

        model_kwargs = {
            k: val for k, val in init_image.items() if k not in input_data
        }
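        # Note (added): judging by the names returned from
        # calculate_disp_quants_batched, joints_quants stacks three per-joint
        # quantities along the last axis: channel 0 the joint-to-base-points
        # distance displacement, channel 1 the displacement along the normal
        # direction, channel 2 the component vertical to it; each has shape
        # bsz x (nf - 1) x nn_joint before the final unsqueeze/cat.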
        if progress:
            # Lazy import so that we don't depend on tqdm.
            from tqdm.auto import tqdm
            indices = tqdm(indices)

        for i_idx, i in enumerate(indices):
            t = th.tensor([i] * shape[0], device=device)
            if randomize_class and 'y' in model_kwargs:
                model_kwargs['y'] = th.randint(low=0, high=model.num_classes,
                                               size=model_kwargs['y'].shape,  # size of y
                                               device=model_kwargs['y'].device)  # device of y
            with th.no_grad():  # inter_optim # progress #
                sample_fn = self.p_sample_with_grad if cond_fn_with_grad else self.p_sample
                out = sample_fn(
                    model,
                    input_data,  ## sample from input data ##
                    t,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn,
                    cond_fn=cond_fn,
                    model_kwargs=model_kwargs,
                    const_noise=const_noise,
                    ## new representation strategies; resolve penetrations for new representations ##
                )
                if self.diff_basejtsrel:
                    # 'pred_joint_quants': pred_joint_quants, 'denoised_latents': denoised_latents #
                    pred_joint_quants = out['pred_joint_quants']
                    denoised_latents = out['denoised_latents']  # denoised latents #
                    # basejtsrel_seq_latents_sample = out["basejtsrel_seq_latents_sample"]  ## basejtsrel output ##
                    # 'real_dec_basejtsrel': real_dec_basejtsrel,
                    # 'dec_obj_base_pts_feats': dec_obj_base_pts_feats,
                    # if self.args.pred_joints_offset:  # basejtsrel_seq_latents_sample: bsz x nf x nnj x 3
                    # sampled_rhand_joints = basejtsrel_seq_latents_sample * std_joints_sequence.unsqueeze(1) + avg_jts_outputs_sample.unsqueeze(1)
                    # sampled_rhand_joints = basejtsrel_seq_latents_sample * std_joints_sequence.unsqueeze(1) + avg_joints_sequence.unsqueeze(1)
                    diff_joints_quants_pred_joints_quants_dist = torch.sum(
                        (joints_quants[..., 0:1] - pred_joint_quants[..., 0:1]) ** 2, dim=-1
                    ).mean()
                    print(f"Current step diff_joints_quants_pred_joints_quants_dist: {diff_joints_quants_pred_joints_quants_dist.item()}")
                    diff_joints_quants_pred_joints_quants_along_normals = torch.sum(
                        (joints_quants[..., 1:2] - pred_joint_quants[..., 1:2]) ** 2, dim=-1
                    ).mean()
                    print(f"Current step diff_joints_quants_pred_joints_quants_along_normals: {diff_joints_quants_pred_joints_quants_along_normals.item()}")
                    diff_joints_quants_pred_joints_quants_vt_normals = torch.sum(
                        (joints_quants[..., 2:] - pred_joint_quants[..., 2:]) ** 2, dim=-1
                    ).mean()
                    print(f"Current step diff_joints_quants_pred_joints_quants_vt_normals: {diff_joints_quants_pred_joints_quants_vt_normals.item()}")
                    # the same diffs against the ground-truth quantities #
                    diff_joints_quants_pred_joints_quants_dist_gt = torch.sum(
                        (gt_dist_joints_to_base_pts_disp.unsqueeze(-1) - pred_joint_quants[..., 0:1]) ** 2, dim=-1
                    ).mean()
                    print(f"Current step diff_joints_quants_pred_joints_quants_dist_gt: {diff_joints_quants_pred_joints_quants_dist_gt.item()}")
                    diff_joints_quants_pred_joints_quants_along_normals_gt = torch.sum(
                        (gt_dist_disp_along_dir.unsqueeze(-1) - pred_joint_quants[..., 1:2]) ** 2, dim=-1
                    ).mean()
                    print(f"Current step diff_joints_quants_pred_joints_quants_along_normals_gt: {diff_joints_quants_pred_joints_quants_along_normals_gt.item()}")
                    diff_joints_quants_pred_joints_quants_vt_normals_gt = torch.sum(
                        (gt_dist_disp_vt_dir.unsqueeze(-1) - pred_joint_quants[..., 2:]) ** 2, dim=-1
                    ).mean()
                    print(f"Current step diff_joints_quants_pred_joints_quants_vt_normals_gt: {diff_joints_quants_pred_joints_quants_vt_normals_gt.item()}")
                    basejtsrel_seq_dec_in_dict = {  # finetune_with_cond
                        # 'pert_avg_joints_sequence': out["avg_jts_outputs_sample"] if 'avg_jts_outputs_sample' in out else pert_avg_joints_sequence,  ## for avg-jts sequence ##
                        # 'pert_rel_base_pts_to_rhand_joints': pert_rel_base_pts_to_rhand_joints,
                        # 'sampled_rhand_joints': sampled_rhand_joints,  ## sampled rhand joints ##
                        # 'pert_joints_offset_sequence': out["basejtsrel_seq_latents_sample"],
                        'pert_joints_quants_latents': denoised_latents,
                        'pred_joint_quants': pred_joint_quants,
                        'sampled_rhand_joints': rhand_joints,
                    }
                    input_data.update(basejtsrel_seq_dec_in_dict)
                else:
                    basejtsrel_seq_dec_in_dict = {}

                if self.args.diff_realbasejtsrel_to_joints:
                    # predicted x_{t-1} (normalized), rel to joints #
                    dec_joints_offset_output = out['dec_joints_offset_output']
                    ## minus normed base pts here: from normed pts and offset outputs ##
                    pert_rel_base_pts_to_joints_for_jts_pred = dec_joints_offset_output.unsqueeze(-2) - normed_base_pts.unsqueeze(1).unsqueeze(1)
                    # predicted x_{t-1} before normalization #
                    sampled_rhand_joints = dec_joints_offset_output * std_joints_sequence.unsqueeze(1) + avg_joints_sequence.unsqueeze(1)
                    realbasejtsrel_to_joints_dec_in_dict = {
                        'pert_rel_base_pts_to_joints_for_jts_pred': pert_rel_base_pts_to_joints_for_jts_pred,  # bsz x nf x nnj x nnb x 3
                        'sampled_rhand_joints': sampled_rhand_joints,
                        'pert_joints_offset_sequence': dec_joints_offset_output,
                    }
                    input_data.update(realbasejtsrel_to_joints_dec_in_dict)

                if self.diff_realbasejtsrel:
                    real_dec_basejtsrel = out["real_dec_basejtsrel"]  # bsz x nf x nnj x nnb x 3
                    #### add_noise_onjts; add_noise_onjts_single ####
                    if self.args.real_basejtsrel_norm_stra == "std" and (not self.args.add_noise_onjts) and (not self.args.add_noise_onjts_single):
                        real_dec_basejtsrel_pred_sample = real_dec_basejtsrel * std_rel_base_pts_to_rhand_joints.unsqueeze(1).unsqueeze(1) + avg_rel_base_pts_to_rhand_joints.unsqueeze(1).unsqueeze(1)
                        real_dec_basejtsrel_pred_sample = real_dec_basejtsrel_pred_sample + normed_base_pts.unsqueeze(1).unsqueeze(1)
                    else:
                        # real_dec_basejtsrel_pred_sample = real_dec_basejtsrel.clone()
                        real_dec_basejtsrel_pred_sample = real_dec_basejtsrel + normed_base_pts.unsqueeze(1).unsqueeze(1)  # bsz x nf x nnj x nnb x 3
                    if self.args.use_t == 1000:
                        sampled_rhand_joints = real_dec_basejtsrel_pred_sample[..., self.args.sel_basepts_idx, :]
                        # sampled_rhand_joints = real_dec_basejtsrel_pred_sample.mean(dim=-2)
                    else:
                        sampled_rhand_joints = real_dec_basejtsrel_pred_sample.mean(dim=-2)  # bsz x nf x nnj x 3
                        # sampled_rhand_joints = real_dec_basejtsrel_pred_sample[..., self.args.sel_basepts_idx, :]
                        # sampled_rhand_joints = real_dec_basejtsrel_pred_sample[..., 0, :]
                    # sampled_rhand_joints = sampled_rhand_joints * std_joints_sequence.unsqueeze(1) + avg_joints_sequence.unsqueeze(1)
                    real_basejtsrel_dec_in_dict = {
                        'pert_rel_base_pts_to_rhand_joints': real_dec_basejtsrel,  ## real dec basejtsrel ##
                        # 'sampled_rhand_joints': sampled_rhand_joints,
                    }
                    if not self.diff_basejtsrel:
                        real_basejtsrel_dec_in_dict['sampled_rhand_joints'] = sampled_rhand_joints
                    if self.args.train_enc:
                        real_basejtsrel_dec_in_dict['pert_obj_base_pts_feats'] = out['dec_obj_base_pts_feats']
                    input_data.update(real_basejtsrel_dec_in_dict)
                else:
                    real_basejtsrel_dec_in_dict = {}

                if self.diff_basejtse:
                    # basejtse_seq_latents_sample = out["basejtse_seq_latents_sample"]  ## seq latents ##
                    pert_dec_e_along_normals = out["dec_e_along_normals"]
                    pert_dec_e_vt_normals = out["dec_e_vt_normals"]
                    dec_e_along_normals = pert_dec_e_along_normals * init_image['per_frame_std_disp_along_normals'] + init_image['per_frame_avg_disp_along_normals']
                    dec_e_vt_normals = pert_dec_e_vt_normals * init_image['per_frame_std_disp_vt_normals'] + init_image['per_frame_avg_disp_vt_normals']
                    dec_e_along_normals = torch.clamp(dec_e_along_normals, min=0.)
                    dec_e_vt_normals = torch.clamp(dec_e_vt_normals, min=0.)
                    # scale base; model constraints and model impacts from object a to object c #
                    basejtse_seq_input_dict = {
                        'pert_e_disp_rel_to_base_along_normals': pert_dec_e_along_normals,
                        'pert_e_disp_rel_to_base_vt_normals': pert_dec_e_vt_normals,
                        'e_disp_rel_to_base_along_normals': dec_e_along_normals,
                        'e_disp_rel_to_baes_vt_normals': dec_e_vt_normals,
                        'sampled_rhand_joints': rhand_joints,
                    }
                    input_data.update(basejtse_seq_input_dict)
                else:
                    basejtse_seq_input_dict = {}

            yield input_data

    def ddim_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        eta=0.0,
    ):
        """
        Sample x_{t-1} from the model using DDIM.

        Same usage as p_sample().
        """
        out_orig = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        if cond_fn is not None:
            out = self.condition_score(cond_fn, out_orig, x, t, model_kwargs=model_kwargs)
        else:
            out = out_orig

        # Usually our model outputs epsilon, but we re-derive it
        # in case we used x_start or x_prev prediction.
        eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])

        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
        alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
        sigma = (
            eta
            * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
            * th.sqrt(1 - alpha_bar / alpha_bar_prev)
        )
        # Equation 12.
        noise = th.randn_like(x)
        mean_pred = (
            out["pred_xstart"] * th.sqrt(alpha_bar_prev)
            + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
        )
        nonzero_mask = (
            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        )  # no noise when t == 0
        sample = mean_pred + nonzero_mask * sigma * noise
        return {"sample": sample, "pred_xstart": out_orig["pred_xstart"]}
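    # Illustrative sketch (added): the eta-parameterized DDIM update used by
    # ddim_sample above, restated on plain tensors (the alpha_bar arguments
    # are the extracted cumulative products). eta = 0 gives the deterministic
    # DDIM path; eta = 1 recovers DDPM-like stochasticity.
    @staticmethod
    def _ddim_step_sketch(pred_xstart, eps, alpha_bar, alpha_bar_prev, t, eta=0.0):
        sigma = (
            eta
            * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
            * th.sqrt(1 - alpha_bar / alpha_bar_prev)
        )
        mean_pred = (
            pred_xstart * th.sqrt(alpha_bar_prev)
            + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
        )
        nonzero_mask = (t != 0).float().view(-1, *([1] * (len(mean_pred.shape) - 1)))
        return mean_pred + nonzero_mask * sigma * th.randn_like(mean_pred)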
    def ddim_sample_with_grad(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        eta=0.0,
    ):
        """
        Sample x_{t-1} from the model using DDIM.

        Same usage as p_sample().
        """
        with th.enable_grad():
            x = x.detach().requires_grad_()
            out_orig = self.p_mean_variance(
                model,
                x,
                t,
                clip_denoised=clip_denoised,
                denoised_fn=denoised_fn,
                model_kwargs=model_kwargs,
            )
            if cond_fn is not None:
                out = self.condition_score_with_grad(cond_fn, out_orig, x, t,
                                                     model_kwargs=model_kwargs)
            else:
                out = out_orig

        out["pred_xstart"] = out["pred_xstart"].detach()
        # Usually our model outputs epsilon, but we re-derive it
        # in case we used x_start or x_prev prediction.
        eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])

        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
        alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
        sigma = (
            eta
            * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
            * th.sqrt(1 - alpha_bar / alpha_bar_prev)
        )
        # Equation 12.
        noise = th.randn_like(x)
        mean_pred = (
            out["pred_xstart"] * th.sqrt(alpha_bar_prev)
            + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
        )
        nonzero_mask = (
            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        )  # no noise when t == 0
        sample = mean_pred + nonzero_mask * sigma * noise
        return {"sample": sample, "pred_xstart": out_orig["pred_xstart"].detach()}

    def ddim_reverse_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        model_kwargs=None,
        eta=0.0,
    ):
        """
        Sample x_{t+1} from the model using DDIM reverse ODE.
        """
        assert eta == 0.0, "Reverse ODE only for deterministic path"
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        # Usually our model outputs epsilon, but we re-derive it
        # in case we used x_start or x_prev prediction.
        eps = (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
            - out["pred_xstart"]
        ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
        alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
        # Equation 12. reversed
        mean_pred = (
            out["pred_xstart"] * th.sqrt(alpha_bar_next)
            + th.sqrt(1 - alpha_bar_next) * eps
        )
        return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}

    def ddim_sample_loop(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        eta=0.0,
        skip_timesteps=0,
        init_image=None,
        randomize_class=False,
        cond_fn_with_grad=False,
        dump_steps=None,
        const_noise=False,
    ):
        """
        Generate samples from the model using DDIM.

        Same usage as p_sample_loop().
        """
        if dump_steps is not None:
            raise NotImplementedError()
        if const_noise == True:
            raise NotImplementedError()

        final = None
        for sample in self.ddim_sample_loop_progressive(
            model,
            shape,
            noise=noise,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            cond_fn=cond_fn,
            model_kwargs=model_kwargs,
            device=device,
            progress=progress,
            eta=eta,
            skip_timesteps=skip_timesteps,
            init_image=init_image,
            randomize_class=randomize_class,
            cond_fn_with_grad=cond_fn_with_grad,
        ):
            final = sample
        return final["sample"]

    def ddim_sample_loop_progressive(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        eta=0.0,
        skip_timesteps=0,
        init_image=None,
        randomize_class=False,
        cond_fn_with_grad=False,
    ):
        """
        Use DDIM to sample from the model and yield intermediate samples from
        each timestep of DDIM.

        Same usage as p_sample_loop_progressive().
        """
        if device is None:
            device = next(model.parameters()).device
        assert isinstance(shape, (tuple, list))
        if noise is not None:
            img = noise
        else:
            img = th.randn(*shape, device=device)

        if skip_timesteps and init_image is None:
            init_image = th.zeros_like(img)

        indices = list(range(self.num_timesteps - skip_timesteps))[::-1]

        if init_image is not None:
            my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0]
            img = self.q_sample(init_image, my_t, img)

        if progress:
            # Lazy import so that we don't depend on tqdm.
            from tqdm.auto import tqdm
            indices = tqdm(indices)

        for i in indices:
            t = th.tensor([i] * shape[0], device=device)
            if randomize_class and 'y' in model_kwargs:
                model_kwargs['y'] = th.randint(low=0, high=model.num_classes,
                                               size=model_kwargs['y'].shape,
                                               device=model_kwargs['y'].device)
            with th.no_grad():
                sample_fn = self.ddim_sample_with_grad if cond_fn_with_grad else self.ddim_sample
                out = sample_fn(
                    model,
                    img,
                    t,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn,
                    cond_fn=cond_fn,
                    model_kwargs=model_kwargs,
                    eta=eta,
                )
                yield out
                img = out["sample"]

    def plms_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        cond_fn_with_grad=False,
        order=2,
        old_out=None,
    ):
        """
        Sample x_{t-1} from the model using Pseudo Linear Multistep.

        Same usage as p_sample().
        """
        if not int(order) or not 1 <= order <= 4:
            raise ValueError('order is invalid (should be int from 1-4).')

        def get_model_output(x, t):
            with th.set_grad_enabled(cond_fn_with_grad and cond_fn is not None):
                x = x.detach().requires_grad_() if cond_fn_with_grad else x
                out_orig = self.p_mean_variance(
                    model,
                    x,
                    t,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn,
                    model_kwargs=model_kwargs,
                )
                if cond_fn is not None:
                    if cond_fn_with_grad:
                        out = self.condition_score_with_grad(cond_fn, out_orig, x, t, model_kwargs=model_kwargs)
                        x = x.detach()
                    else:
                        out = self.condition_score(cond_fn, out_orig, x, t, model_kwargs=model_kwargs)
                else:
                    out = out_orig

            # Usually our model outputs epsilon, but we re-derive it
            # in case we used x_start or x_prev prediction.
            eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
            return eps, out, out_orig

        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
        alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
        eps, out, out_orig = get_model_output(x, t)

        if order > 1 and old_out is None:
            # Pseudo Improved Euler
            old_eps = [eps]
            mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev) * eps
            eps_2, _, _ = get_model_output(mean_pred, t - 1)
            eps_prime = (eps + eps_2) / 2
            pred_prime = self._predict_xstart_from_eps(x, t, eps_prime)
            mean_pred = pred_prime * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev) * eps_prime
        else:
            # Pseudo Linear Multistep (Adams-Bashforth)
            old_eps = old_out["old_eps"]
            old_eps.append(eps)
            cur_order = min(order, len(old_eps))
            if cur_order == 1:
                eps_prime = old_eps[-1]
            elif cur_order == 2:
                eps_prime = (3 * old_eps[-1] - old_eps[-2]) / 2
            elif cur_order == 3:
                eps_prime = (23 * old_eps[-1] - 16 * old_eps[-2] + 5 * old_eps[-3]) / 12
            elif cur_order == 4:
                eps_prime = (55 * old_eps[-1] - 59 * old_eps[-2] + 37 * old_eps[-3] - 9 * old_eps[-4]) / 24
            else:
                raise RuntimeError('cur_order is invalid.')
            pred_prime = self._predict_xstart_from_eps(x, t, eps_prime)
            mean_pred = pred_prime * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev) * eps_prime

        if len(old_eps) >= order:
            old_eps.pop(0)

        nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        sample = mean_pred * nonzero_mask + out["pred_xstart"] * (1 - nonzero_mask)

        return {"sample": sample, "pred_xstart": out_orig["pred_xstart"], "old_eps": old_eps}
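    # Note (added): the eps_prime combinations in plms_sample are the
    # Adams-Bashforth coefficients for orders 1-4:
    #   order 1: e_n
    #   order 2: (3 e_n - e_{n-1}) / 2
    #   order 3: (23 e_n - 16 e_{n-1} + 5 e_{n-2}) / 12
    #   order 4: (55 e_n - 59 e_{n-1} + 37 e_{n-2} - 9 e_{n-3}) / 24
    # with e_n the newest eps; when order > 1 and no history exists yet, a
    # Heun-style (pseudo improved Euler) start-up step is used instead.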
""" final = None for sample in self.plms_sample_loop_progressive( model, shape, noise=noise, clip_denoised=clip_denoised, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, device=device, progress=progress, skip_timesteps=skip_timesteps, init_image=init_image, randomize_class=randomize_class, cond_fn_with_grad=cond_fn_with_grad, order=order, ): final = sample return final["sample"] def plms_sample_loop_progressive( self, model, shape, noise=None, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, progress=False, skip_timesteps=0, init_image=None, randomize_class=False, cond_fn_with_grad=False, order=2, ): """ Use PLMS to sample from the model and yield intermediate samples from each timestep of PLMS. Same usage as p_sample_loop_progressive(). """ if device is None: device = next(model.parameters()).device assert isinstance(shape, (tuple, list)) if noise is not None: img = noise else: img = th.randn(*shape, device=device) if skip_timesteps and init_image is None: init_image = th.zeros_like(img) indices = list(range(self.num_timesteps - skip_timesteps))[::-1] if init_image is not None: my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0] img = self.q_sample(init_image, my_t, img) if progress: # Lazy import so that we don't depend on tqdm. from tqdm.auto import tqdm indices = tqdm(indices) old_out = None for i in indices: t = th.tensor([i] * shape[0], device=device) if randomize_class and 'y' in model_kwargs: model_kwargs['y'] = th.randint(low=0, high=model.num_classes, size=model_kwargs['y'].shape, device=model_kwargs['y'].device) with th.no_grad(): out = self.plms_sample( model, img, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, cond_fn_with_grad=cond_fn_with_grad, order=order, old_out=old_out, ) yield out old_out = out img = out["sample"] def _vb_terms_bpd( self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None ): """ Get a term for the variational lower-bound. The resulting units are bits (rather than nats, as one might expect). This allows for comparison to other papers. :return: a dict with the following keys: - 'output': a shape [N] tensor of NLLs or KLs. - 'pred_xstart': the x_0 predictions. """ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( x_start=x_start, x_t=x_t, t=t ) out = self.p_mean_variance( model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs ) kl = normal_kl( true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] ) kl = mean_flat(kl) / np.log(2.0) decoder_nll = -discretized_gaussian_log_likelihood( x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] ) assert decoder_nll.shape == x_start.shape decoder_nll = mean_flat(decoder_nll) / np.log(2.0) # At the first timestep return the decoder NLL, # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) output = th.where((t == 0), decoder_nll, kl) return {"output": output, "pred_xstart": out["pred_xstart"]} ## ## training losses ## ## training losses ## def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, dataset=None): """ # training losses # training losses for rel/dist representations ## Compute training losses for a single timestep. :param model: the model to evaluate loss on. :param x_start: the [N x C x ...] tensor of inputs. :param t: a batch of timestep indices. :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. 
    ## training losses ##
    def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, dataset=None):
        """
        Compute training losses for a single timestep.
        # training losses for rel/dist representations #

        :param model: the model to evaluate loss on.
        :param x_start: the [N x C x ...] tensor of inputs.
        :param t: a batch of timestep indices.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param noise: if specified, the specific Gaussian noise to try to remove.
        :return: a dict with the key "loss" containing a tensor of shape [N].
                 Some mean or variance settings may also have other keys.
        """
        # if self.args.train_diff:  # set encoders to eval mode
        #     # print(f"Setting encoders to eval mode")
        #     model.model.set_enc_to_eval()
        enc = model.model  ## model.model
        mask = model_kwargs['y']['mask']  ## rot2xyz
        get_xyz = lambda sample: enc.rot2xyz(sample, mask=None, pose_rep=enc.pose_rep, translation=enc.translation,
                                             glob=enc.glob,  ## rot2xyz ##
                                             # jointstype='vertices',  # 3.4 iter/sec # USED ALSO IN MotionCLIP
                                             jointstype='smpl',  # 3.4 iter/sec
                                             vertstrans=False)

        # base_pts: bsz x nnb x 3; base_normals: bsz x nnb x 3 #
        rhand_joints = x_start['rhand_joints']  # bsz x ws x nnjts x 3
        base_pts = x_start['base_pts']  # bsz x nnbase x 3
        base_normals = x_start['base_normals']  # bsz x ws x nnbase x 3
        obj_rot = x_start['obj_rot']  # bsz x nf x 3 x 3
        obj_transl = x_start['obj_transl']  # bsz x nf x 3 --> transl of obj base

        base_pts = torch.matmul(base_pts.unsqueeze(1), obj_rot) + obj_transl.unsqueeze(-2)  # bsz x nf x nn_base_pts x 3
        base_normals = torch.matmul(base_normals.unsqueeze(1), obj_rot)  # bsz x nf x nn_base_pts x 3
        rhand_joints = torch.matmul(rhand_joints, obj_rot) + obj_transl.unsqueeze(-2)  # bsz x nf x nnj x 3

        # calculate_disp_quants_batched: each quantity has shape bsz x (nf - 1) x nn_joint #
        dist_joints_to_base_pts_disp, dist_disp_along_dir, dist_disp_vt_dir = calculate_disp_quants_batched(rhand_joints, base_pts)
        # joints_quants: bsz x (nf - 1) x nnjoints x 3 #
        joints_quants = torch.cat(  # why can we not predict this well? #
            [dist_joints_to_base_pts_disp.unsqueeze(-1), dist_disp_along_dir.unsqueeze(-1), dist_disp_vt_dir.unsqueeze(-1)], dim=-1
        )
        # joints_quants = torch.cat(
        #     [dist_joints_to_base_pts_disp.unsqueeze(-1), dist_joints_to_base_pts_disp.unsqueeze(-1), dist_joints_to_base_pts_disp.unsqueeze(-1)], dim=-1
        # )
        # joints_quants = torch.cat(
        #     [dist_disp_along_dir.unsqueeze(-1), dist_disp_along_dir.unsqueeze(-1), dist_disp_along_dir.unsqueeze(-1)], dim=-1
        # )
        # joints_quants = torch.cat(
        #     [dist_disp_vt_dir.unsqueeze(-1), dist_disp_vt_dir.unsqueeze(-1), dist_disp_vt_dir.unsqueeze(-1)], dim=-1
        # )
        # pert_e_disp_rel_to_base_vt_normals = self.q_sample(
        #     e_disp_rel_to_baes_vt_normals, t, noise_e_disp_rel_to_base_vt_normals
        # )

        ### joints quants ###
        input_data = {
            'joints_quants': joints_quants
        }
        # input_data.update(
        #     {k: x_start[k].clone() for k in x_start if k not in input_data}
        # )
        if model_kwargs is None:
            model_kwargs = {}

        terms = {}
        # if self.args.train_diff:
        #     with torch.no_grad():
        #         out_dict = model(input_data, self._scale_timesteps(t).clone())
        # else:
        #     # clean_joint_seq_latents: seq x bs x d #
        #     clean_joint_seq_latents = model(input_data, self._scale_timesteps(t).clone())
        ### the strategy of removing noise from corresponding quantities ###
        out_dict = model(input_data, self._scale_timesteps(t).clone())

        KL_loss = 0.
        terms['rot_mse'] = 0.
        ### diff_jts ###
        # resume checkpoints #
        dec_in_dict = {}
        if self.diff_basejtsrel:  # diff_joint_quants #
            # if 'basejtsrel_output' in out_dict:
            #     basejtsrel_output = out_dict['basejtsrel_output'].transpose(-2, -3).contiguous()
            #     avg_jts_outputs = out_dict['avg_jts_outputs']
            #     # print(f"basejtsrel_output: {basejtsrel_output.size()}, noise_rel_base_pts_to_rhand_joints: {noise_rel_base_pts_to_rhand_joints.size()}, rel_base_pts_to_rhand_joints: {rel_base_pts_to_rhand_joints.size()}")
            #     if self.args.pred_diff_noise:
            #         basejtsrel_denoising_loss = torch.sum((basejtsrel_output - noise_rel_base_pts_to_rhand_joints) ** 2, dim=-1).mean(dim=-1).mean(dim=-1).mean(dim=-1)
            #         avgjts_denoising_loss = torch.sum((avg_jts_outputs - noise_avg_joints_sequence) ** 2, dim=-1).mean(dim=-1)
            #     else:
            #         basejtsrel_denoising_loss = torch.sum((basejtsrel_output - rel_base_pts_to_rhand_joints) ** 2, dim=-1).mean(dim=-1).mean(dim=-1).mean(dim=-1)
            #         avgjts_denoising_loss = torch.sum((avg_jts_outputs - avg_joints_sequence) ** 2, dim=-1).mean(dim=-1)
            # else:
            joints_quants_output = out_dict['joints_quants_output']
            # joints_offset_output = out_dict['joints_offset_output']
            # if self.args.pred_diff_noise:
            #     basejtsrel_denoising_loss = torch.sum((joints_offset_output - noise_joints_offset_sequence) ** 2, dim=-1).mean(dim=-1).mean(dim=-1)  # bsz x ws x nnjts x 3 --> mean over dim=-1
            # else:
            #     basejtsrel_denoising_loss = torch.sum((joints_offset_output - joints_offset_sequence) ** 2, dim=-1).mean(dim=-1).mean(dim=-1)
            # if 'avg_jts_outputs' in out_dict:  # avg jts outputs #
            #     avg_jts_outputs = out_dict['avg_jts_outputs']
            #     if self.args.pred_diff_noise:
            #         avgjts_denoising_loss = torch.sum((avg_jts_outputs - noise_avg_joints_sequence) ** 2, dim=-1).mean(dim=-1)
            #     else:
            #         avgjts_denoising_loss = torch.sum((avg_jts_outputs - avg_joints_sequence) ** 2, dim=-1).mean(dim=-1)
            # else:
            #     avgjts_denoising_loss = torch.zeros_like(basejtsrel_denoising_loss)
            if self.args.joint_quants_nn == 2:
                # joint_quant_pred_loss = torch.sum(
                #     (joints_quants_output[..., :2] - joints_quants[..., :2]) ** 2, dim=-1
                # ).mean(dim=-1).mean(dim=-1)
                joint_quant_pred_loss = torch.sum(
                    (joints_quants_output[..., 1:] - joints_quants[..., 1:]) ** 2, dim=-1
                ).mean(dim=-1).mean(dim=-1)
            else:
                joint_quant_pred_loss = torch.sum(
                    (joints_quants_output - joints_quants) ** 2, dim=-1
                ).mean(dim=-1).mean(dim=-1)
            joints_quants_latents = out_dict['joints_quants_latents']
            terms['basejtrel_denoising_loss'] = joint_quant_pred_loss
            # terms['avgjts_denoising_loss'] = avgjts_denoising_loss
            terms['rot_mse'] += joint_quant_pred_loss  # + avgjts_denoising_loss  ## joint_quant_pred_loss -> quant pred loss ##
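        # Note (added): with pred_diff_noise the latent branch below is a
        # standard eps-prediction objective: q_sample perturbs the clean
        # joints_quants_latents with noise_joints_quants_latents at timestep
        # t, and the denoiser is penalized with an MSE against that same noise.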
terms['basejtse_vt_normals_pred_loss'] = pred_e_along_vt_loss # terms['rot_mse'] += pred_e_along_normals_loss + pred_e_along_vt_loss # sv_inter_dict = { # 'dec_joints': dec_clean_joint_seq.detach().cpu().numpy(), # 'rhand_joints': rhand_joints.detach().cpu().numpy(), # 't': t.detach().cpu().numpy(), # } # # sv_inter_dict_fn = os.path.join(args.save_dir, ) # joints_quants_latents: noise_joints_quants_latents = torch.randn_like(joints_quants_latents.detach()) # seq_len x bsz x latents pert_joints_quants_latents = self.q_sample( joints_quants_latents.permute(1, 0, 2).detach(), t, noise_joints_quants_latents.permute(1, 0, 2) ).permute(1, 0, 2) # denoising_joint_quants_feats --> denoising joint quants feats # denoised_joints_quants = model.model.denoising_joint_quants_feats(pert_joints_quants_latents, self._scale_timesteps(t).clone()) # joints_quants_latent_denoising_loss joints_quants_latent_denoising_loss = torch.sum( (denoised_joints_quants.permute(1, 0, 2) - noise_joints_quants_latents.permute(1, 0, 2)) ** 2, dim=-1 ).mean(dim=-1) # bsz size of the denoising loss # if not self.args.train_enc: terms['avgjts_denoising_loss'] = joints_quants_latent_denoising_loss terms['rot_mse'] += joints_quants_latent_denoising_loss # sv_out_in = { # 'model_output': model_output.detach().cpu().numpy(), # 'target': target.detach().cpu().numpy(), # 't': t.detach().cpu().numpy(), # } import os import datetime cur_time_stamp = datetime.datetime.now().timestamp() cur_time_stamp = str(cur_time_stamp) # sv_dir_rt = "/data1/sim/motion-diffusion-model" # sv_out_fn = os.path.join(sv_dir_rt, f"out_{cur_time_stamp}.npy") # np.save(sv_out_fn,sv_inter_dict ) # print(f"Samples saved to {sv_out_fn}") target_xyz, model_output_xyz = None, None if self.lambda_rcxyz > 0. and dataset.dataname not in ['motion_ours']: print(f"Calculating lambda_rcxyz!!!") target_xyz = get_xyz(target) # [bs, nvertices(vertices)/njoints(smpl), 3, nframes] model_output_xyz = get_xyz(model_output) # [bs, nvertices, 3, nframes] terms["rcxyz_mse"] = self.masked_l2(target_xyz, model_output_xyz, mask) # mean_flat((target_xyz - model_output_xyz) ** 2) if self.lambda_vel_rcxyz > 0.: if self.data_rep == 'rot6d' and dataset.dataname in ['humanact12', 'uestc']: target_xyz = get_xyz(target) if target_xyz is None else target_xyz model_output_xyz = get_xyz(model_output) if model_output_xyz is None else model_output_xyz target_xyz_vel = (target_xyz[:, :, :, 1:] - target_xyz[:, :, :, :-1]) model_output_xyz_vel = (model_output_xyz[:, :, :, 1:] - model_output_xyz[:, :, :, :-1]) terms["vel_xyz_mse"] = self.masked_l2(target_xyz_vel, model_output_xyz_vel, mask[:, :, :, 1:]) if self.lambda_fc > 0.: ## lambda fc ## torch.autograd.set_detect_anomaly(True) if self.data_rep == 'rot6d' and dataset.dataname in ['humanact12', 'uestc']: target_xyz = get_xyz(target) if target_xyz is None else target_xyz model_output_xyz = get_xyz(model_output) if model_output_xyz is None else model_output_xyz # 'L_Ankle', # 7, 'R_Ankle', # 8 , 'L_Foot', # 10, 'R_Foot', # 11 l_ankle_idx, r_ankle_idx, l_foot_idx, r_foot_idx = 7, 8, 10, 11 relevant_joints = [l_ankle_idx, l_foot_idx, r_ankle_idx, r_foot_idx] gt_joint_xyz = target_xyz[:, relevant_joints, :, :] # [BatchSize, 4, 3, Frames] gt_joint_vel = torch.linalg.norm(gt_joint_xyz[:, :, :, 1:] - gt_joint_xyz[:, :, :, :-1], axis=2) # [BatchSize, 4, Frames] fc_mask = torch.unsqueeze((gt_joint_vel <= 0.01), dim=2).repeat(1, 1, 3, 1) pred_joint_xyz = model_output_xyz[:, relevant_joints, :, :] # [BatchSize, 4, 3, Frames] pred_vel = pred_joint_xyz[:, :, 
                pred_vel = pred_joint_xyz[:, :, :, 1:] - pred_joint_xyz[:, :, :, :-1]
                pred_vel[~fc_mask] = 0
                terms["fc"] = self.masked_l2(pred_vel,
                                             torch.zeros(pred_vel.shape, device=pred_vel.device),
                                             mask[:, :, :, 1:])
        self.lambda_vel = 0.  # velocity term disabled here #
        if self.lambda_vel > 0.:
            target_vel = (target[..., 1:] - target[..., :-1])
            model_output_vel = (model_output[..., 1:] - model_output[..., :-1])
            terms["vel_mse"] = self.masked_l2(target_vel[:, :-1, :, :],  # Remove last joint, it is the root location!
                                              model_output_vel[:, :-1, :, :],
                                              mask[:, :, :, 1:])  # mean_flat((target_vel - model_output_vel) ** 2)

        terms["loss"] = terms["rot_mse"] + terms.get('vb', 0.) + \
                        (self.lambda_vel * terms.get('vel_mse', 0.)) + \
                        (self.lambda_rcxyz * terms.get('rcxyz_mse', 0.)) + \
                        (self.lambda_fc * terms.get('fc', 0.))
        # else:
        #     raise NotImplementedError(self.loss_type)
        return terms  ## training losses

    def predict_sample_single_step(self, model, x_start, t, model_kwargs=None, noise=None, dataset=None):
        """
        Compute training losses for a single timestep.

        :param model: the model to evaluate loss on.
        :param x_start: the [N x C x ...] tensor of inputs.
        :param t: a batch of timestep indices.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param noise: if specified, the specific Gaussian noise to try to remove.
        :return: a dict with the key "loss" containing a tensor of shape [N].
                 Some mean or variance settings may also have other keys.
        """
        # enc = model.model._modules['module']
        # enc = model.model
        if model_kwargs is None:  ## predict single step ##
            model_kwargs = {}
        mask = model_kwargs['y']['mask']  ## rot2xyz mask ##

        # ### avg_joints, std_joints ### #
        if 'avg_joints' in model_kwargs['y']:
            avg_joints = model_kwargs['y']['avg_joints'].unsqueeze(-1)
            std_joints = model_kwargs['y']['std_joints'].unsqueeze(-1).unsqueeze(-1)
        else:
            avg_joints = None
            std_joints = None
        # ### avg_joints, std_joints ### #

        # get_xyz = lambda sample: enc.rot2xyz(sample, mask=None, pose_rep=enc.pose_rep, translation=enc.translation,
        #                                      glob=enc.glob,  ## rot2xyz ##
        #                                      # jointstype='vertices',  # 3.4 iter/sec # USED ALSO IN MotionCLIP
        #                                      jointstype='smpl',  # 3.4 iter/sec
        #                                      vertstrans=False)

        if noise is None:
            noise = th.randn_like(x_start)
        x_t = self.q_sample(x_start, t, noise=noise)  # get x_t from x_start; t controls the noise level #

        terms = {}
        if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
            terms["loss"] = self._vb_terms_bpd(  ## vb terms bpd ##
                model=model,
                x_start=x_start,
                x_t=x_t,
                t=t,
                clip_denoised=False,
                model_kwargs=model_kwargs,
            )["output"]
            if self.loss_type == LossType.RESCALED_KL:
                terms["loss"] *= self.num_timesteps
        elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
            model_output = model(x_t, self._scale_timesteps(t), **model_kwargs)  # model output for x_t #
            if self.model_var_type in [
                ModelVarType.LEARNED,
                ModelVarType.LEARNED_RANGE,
            ]:
                B, C = x_t.shape[:2]
                assert model_output.shape == (B, C * 2, *x_t.shape[2:])
                model_output, model_var_values = th.split(model_output, C, dim=1)
                # Learn the variance using the variational bound, but don't let
                # it affect our mean prediction.
                frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
                terms["vb"] = self._vb_terms_bpd(
                    model=lambda *args, r=frozen_out: r,
                    x_start=x_start,
                    x_t=x_t,
                    t=t,
                    clip_denoised=False,
                )["output"]
                if self.loss_type == LossType.RESCALED_MSE:
                    # Divide by 1000 for equivalence with the initial implementation.
                    # Without a factor of 1/1000, the VB term hurts the MSE term.
                    terms["vb"] *= self.num_timesteps / 1000.0

            target = {  # q-posterior mean-variance target #
                ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
                    x_start=x_start, x_t=x_t, t=t
                )[0],
                ModelMeanType.START_X: x_start,
                ModelMeanType.EPSILON: noise,
            }[self.model_mean_type]
            assert model_output.shape == target.shape == x_start.shape  # [bs, njoints, nfeats, nframes]
            terms["rot_mse"] = self.masked_l2(target, model_output, mask)  # mean_flat(rot_mse)

            # ### avg_joints, std_joints ### #
            if avg_joints is not None:
                print(f"model_output: {model_output.size()}, target: {target.size()}, std_joints: {std_joints.size()}, avg_joints: {avg_joints.size()}")
                print("Denormalizing joints...")
                model_output = (model_output * std_joints) + avg_joints
                target = (target * std_joints) + avg_joints

            sv_out_in = {  # save model inputs/outputs for inspection #
                'model_output': model_output.detach().cpu().numpy(),
                'target': target.detach().cpu().numpy(),
                't': t.detach().cpu().numpy(),
            }
            import os
            import datetime
            cur_time_stamp = str(datetime.datetime.now().timestamp())
            sv_dir_rt = "/data1/sim/motion-diffusion-model"
            sv_out_fn = os.path.join(sv_dir_rt, f"out_{cur_time_stamp}.npy")
            np.save(sv_out_fn, sv_out_in)
            print(f"Samples saved to {sv_out_fn}")

            target_xyz, model_output_xyz = None, None
            self.lambda_rcxyz = 0.  # rc-xyz term disabled here #
            if self.lambda_rcxyz > 0.:
                target_xyz = get_xyz(target)  # [bs, nvertices(vertices)/njoints(smpl), 3, nframes]
                model_output_xyz = get_xyz(model_output)  # [bs, nvertices, 3, nframes]
                terms["rcxyz_mse"] = self.masked_l2(target_xyz, model_output_xyz, mask)  # mean_flat((target_xyz - model_output_xyz) ** 2)

            if self.lambda_vel_rcxyz > 0.:
                if self.data_rep == 'rot6d' and dataset.dataname in ['humanact12', 'uestc']:
                    target_xyz = get_xyz(target) if target_xyz is None else target_xyz
                    model_output_xyz = get_xyz(model_output) if model_output_xyz is None else model_output_xyz
                    target_xyz_vel = (target_xyz[:, :, :, 1:] - target_xyz[:, :, :, :-1])
                    model_output_xyz_vel = (model_output_xyz[:, :, :, 1:] - model_output_xyz[:, :, :, :-1])
                    terms["vel_xyz_mse"] = self.masked_l2(target_xyz_vel, model_output_xyz_vel, mask[:, :, :, 1:])

            if self.lambda_fc > 0.:  ## lambda fc ##
                torch.autograd.set_detect_anomaly(True)
                if self.data_rep == 'rot6d' and dataset.dataname in ['humanact12', 'uestc']:
                    target_xyz = get_xyz(target) if target_xyz is None else target_xyz
                    model_output_xyz = get_xyz(model_output) if model_output_xyz is None else model_output_xyz
                    # 'L_Ankle' = 7, 'R_Ankle' = 8, 'L_Foot' = 10, 'R_Foot' = 11
                    l_ankle_idx, r_ankle_idx, l_foot_idx, r_foot_idx = 7, 8, 10, 11
                    relevant_joints = [l_ankle_idx, l_foot_idx, r_ankle_idx, r_foot_idx]
                    gt_joint_xyz = target_xyz[:, relevant_joints, :, :]  # [BatchSize, 4, 3, Frames]
                    gt_joint_vel = torch.linalg.norm(gt_joint_xyz[:, :, :, 1:] - gt_joint_xyz[:, :, :, :-1], axis=2)  # [BatchSize, 4, Frames]
                    fc_mask = torch.unsqueeze((gt_joint_vel <= 0.01), dim=2).repeat(1, 1, 3, 1)
                    pred_joint_xyz = model_output_xyz[:, relevant_joints, :, :]  # [BatchSize, 4, 3, Frames]
                    pred_vel = pred_joint_xyz[:, :, :, 1:] - pred_joint_xyz[:, :, :, :-1]
                    pred_vel[~fc_mask] = 0
                    terms["fc"] = self.masked_l2(pred_vel,
                                                 torch.zeros(pred_vel.shape, device=pred_vel.device),
                                                 mask[:, :, :, 1:])
            if self.lambda_vel > 0.:
                target_vel = (target[..., 1:] - target[..., :-1])
                model_output_vel = (model_output[..., 1:] - model_output[..., :-1])
                terms["vel_mse"] = self.masked_l2(target_vel[:, :-1, :, :],  # Remove last joint, it is the root location!
                                                  model_output_vel[:, :-1, :, :],
                                                  mask[:, :, :, 1:])  # mean_flat((target_vel - model_output_vel) ** 2)
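            # Descriptive note: the final objective combines the data-space MSE
            # ("rot_mse") with the variational-bound term (when the variance is
            # learned) and the lambda-weighted geometric regularizers.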
            terms["loss"] = terms["rot_mse"] + terms.get('vb', 0.) + \
                            (self.lambda_vel * terms.get('vel_mse', 0.)) + \
                            (self.lambda_rcxyz * terms.get('rcxyz_mse', 0.)) + \
                            (self.lambda_fc * terms.get('fc', 0.))
        else:
            raise NotImplementedError(self.loss_type)

        return terms, model_output, target, t

    def fc_loss_rot_repr(self, gt_xyz, pred_xyz, mask):
        def to_np_cpu(x):
            return x.detach().cpu().numpy()
        """
        pose_xyz: SMPL batch tensor of shape: [BatchSize, 24, 3, Frames]
        """
        # 'L_Ankle' = 7, 'R_Ankle' = 8, 'L_Foot' = 10, 'R_Foot' = 11
        l_ankle_idx, r_ankle_idx = 7, 8
        l_foot_idx, r_foot_idx = 10, 11
        """ Contact calculated by 'Kfir Method' (commented code) """
        # contact_signal = torch.zeros((pose_xyz.shape[0], pose_xyz.shape[3], 2), device=pose_xyz.device)  # [BatchSize, Frames, 2]
        # left_xyz = 0.5 * (pose_xyz[:, l_ankle_idx, :, :] + pose_xyz[:, l_foot_idx, :, :])  # [BatchSize, 3, Frames]
        # right_xyz = 0.5 * (pose_xyz[:, r_ankle_idx, :, :] + pose_xyz[:, r_foot_idx, :, :])
        # left_z, right_z = left_xyz[:, 2, :], right_xyz[:, 2, :]  # [BatchSize, Frames]
        # left_velocity = torch.linalg.norm(left_xyz[:, :, 2:] - left_xyz[:, :, :-2], axis=1)  # [BatchSize, Frames]
        # right_velocity = torch.linalg.norm(left_xyz[:, :, 2:] - left_xyz[:, :, :-2], axis=1)
        #
        # left_z_mask = left_z <= torch.mean(torch.sort(left_z)[0][:, :left_z.shape[1] // 5], axis=-1)
        # left_z_mask = torch.stack([left_z_mask, left_z_mask], dim=-1)  # [BatchSize, Frames, 2]
        # left_z_mask[:, :, 1] = False  # Blank right side
        # contact_signal[left_z_mask] = 0.4
        #
        # right_z_mask = right_z <= torch.mean(torch.sort(right_z)[0][:, :right_z.shape[1] // 5], axis=-1)
        # right_z_mask = torch.stack([right_z_mask, right_z_mask], dim=-1)  # [BatchSize, Frames, 2]
        # right_z_mask[:, :, 0] = False  # Blank left side
        # contact_signal[right_z_mask] = 0.4
        #
        # contact_signal[left_z <= (torch.mean(torch.sort(left_z)[:left_z.shape[0] // 5]) + 20), 0] = 1
        # contact_signal[right_z <= (torch.mean(torch.sort(right_z)[:right_z.shape[0] // 5]) + 20), 1] = 1
        # (matplotlib plots of left/right z, velocity, and contact signal)

        gt_joint_xyz = gt_xyz[:, [l_ankle_idx, l_foot_idx, r_ankle_idx, r_foot_idx], :, :]  # [BatchSize, 4, 3, Frames]
        gt_joint_vel = torch.linalg.norm(gt_joint_xyz[:, :, :, 1:] - gt_joint_xyz[:, :, :, :-1], axis=2)  # [BatchSize, 4, Frames]
        fc_mask = (gt_joint_vel <= 0.01)
        pred_joint_xyz = pred_xyz[:, [l_ankle_idx, l_foot_idx, r_ankle_idx, r_foot_idx], :, :]  # [BatchSize, 4, 3, Frames]
        pred_joint_vel = torch.linalg.norm(pred_joint_xyz[:, :, :, 1:] - pred_joint_xyz[:, :, :, :-1], axis=2)  # [BatchSize, 4, Frames]
        pred_joint_vel[~fc_mask] = 0  # Blank non-contact velocity frames. [BS, 4, Frames]
        pred_joint_vel = torch.unsqueeze(pred_joint_vel, dim=2)
        """ DEBUG CODE """
        # print(f'mask: {mask.shape}')
        # print(f'pred_joint_vel: {pred_joint_vel.shape}')
        # plt.title(f'Joint: {joint_idx}')
        # plt.plot(to_np_cpu(gt_joint_vel[0]), label='velocity')
        # plt.plot(to_np_cpu(fc_mask[0]), label='fc')
        # plt.grid()
        # plt.legend()
        # plt.show()
        return self.masked_l2(pred_joint_vel,
                              torch.zeros(pred_joint_vel.shape, device=pred_joint_vel.device),
                              mask[:, :, :, 1:])

    # TODO - NOT USED YET, JUST COMMITTING TO NOT DELETE THIS AND KEEP THE INITIAL IMPLEMENTATION, NOT DONE!
    def foot_contact_loss_humanml3d(self, target, model_output):
        # HumanML3D feature layout (per frame):
        # root_rot_velocity (B, seq_len, 1)
        # root_linear_velocity (B, seq_len, 2)
        # root_y (B, seq_len, 1)
        # ric_data (B, seq_len, (joint_num - 1) * 3), XYZ
        # rot_data (B, seq_len, (joint_num - 1) * 6), 6D
        # local_velocity (B, seq_len, joint_num * 3), XYZ
        # foot contact (B, seq_len, 4)
        target_fc = target[:, -4:, :, :]
        root_rot_velocity = target[:, :1, :, :]
        root_linear_velocity = target[:, 1:3, :, :]
        root_y = target[:, 3:4, :, :]
        ric_data = target[:, 4:67, :, :]  # 4 + (3 * 21) = 67
        rot_data = target[:, 67:193, :, :]  # 67 + (6 * 21) = 193
        local_velocity = target[:, 193:259, :, :]  # 193 + (3 * 22) = 259
        contact = target[:, 259:, :, :]

        contact_mask_gt = contact > 0.5  # contact mask order for indices: fid_l [7, 10], fid_r [8, 11]
        vel_lf_7 = local_velocity[:, 7 * 3:8 * 3, :, :]
        vel_rf_8 = local_velocity[:, 8 * 3:9 * 3, :, :]
        vel_lf_10 = local_velocity[:, 10 * 3:11 * 3, :, :]
        vel_rf_11 = local_velocity[:, 11 * 3:12 * 3, :, :]

        calc_vel_lf_7 = ric_data[:, 6 * 3:7 * 3, :, 1:] - ric_data[:, 6 * 3:7 * 3, :, :-1]
        calc_vel_rf_8 = ric_data[:, 7 * 3:8 * 3, :, 1:] - ric_data[:, 7 * 3:8 * 3, :, :-1]
        calc_vel_lf_10 = ric_data[:, 9 * 3:10 * 3, :, 1:] - ric_data[:, 9 * 3:10 * 3, :, :-1]
        calc_vel_rf_11 = ric_data[:, 10 * 3:11 * 3, :, 1:] - ric_data[:, 10 * 3:11 * 3, :, :-1]

        # vel_foots = torch.stack([vel_lf_7, vel_lf_10, vel_rf_8, vel_rf_11], dim=1)
        for chosen_vel_foot_calc, chosen_vel_foot, joint_idx, contact_mask_idx in zip(
                [calc_vel_lf_7, calc_vel_rf_8, calc_vel_lf_10, calc_vel_rf_11],
                [vel_lf_7, vel_lf_10, vel_rf_8, vel_rf_11],
                [7, 10, 8, 11],
                [0, 1, 2, 3]):
            tmp_mask_gt = contact_mask_gt[:, contact_mask_idx, :, :].cpu().detach().numpy().reshape(-1).astype(int)
            chosen_vel_norm = np.linalg.norm(chosen_vel_foot.cpu().detach().numpy().reshape((3, -1)), axis=0)
            chosen_vel_calc_norm = np.linalg.norm(chosen_vel_foot_calc.cpu().detach().numpy().reshape((3, -1)), axis=0)

            print(tmp_mask_gt.shape)
            print(chosen_vel_foot.shape)
            print(chosen_vel_calc_norm.shape)

            import matplotlib.pyplot as plt
            plt.plot(tmp_mask_gt, label='FC mask')
            plt.plot(chosen_vel_norm, label='Vel. XYZ norm (from vector)')
            plt.plot(chosen_vel_calc_norm, label='Vel. XYZ norm (calculated diff XYZ)')
            plt.title(f'FC idx {contact_mask_idx}, Joint Index {joint_idx}')
            plt.legend()
            plt.show()
        # print(vel_foots.shape)
        return 0

    # TODO - NOT USED YET, JUST COMMITTING TO NOT DELETE THIS AND KEEP THE INITIAL IMPLEMENTATION, NOT DONE!
    def velocity_consistency_loss_humanml3d(self, target, model_output):
        # HumanML3D feature layout (per frame):
        # root_rot_velocity (B, seq_len, 1)
        # root_linear_velocity (B, seq_len, 2)
        # root_y (B, seq_len, 1)
        # ric_data (B, seq_len, (joint_num - 1) * 3), XYZ
        # rot_data (B, seq_len, (joint_num - 1) * 6), 6D
        # local_velocity (B, seq_len, joint_num * 3), XYZ
        # foot contact (B, seq_len, 4)
        target_fc = target[:, -4:, :, :]
        root_rot_velocity = target[:, :1, :, :]
        root_linear_velocity = target[:, 1:3, :, :]
        root_y = target[:, 3:4, :, :]
        ric_data = target[:, 4:67, :, :]  # 4 + (3 * 21) = 67
        rot_data = target[:, 67:193, :, :]  # 67 + (6 * 21) = 193
        local_velocity = target[:, 193:259, :, :]  # 193 + (3 * 22) = 259
        contact = target[:, 259:, :, :]

        calc_vel_from_xyz = ric_data[:, :, :, 1:] - ric_data[:, :, :, :-1]
        velocity_from_vector = local_velocity[:, 3:, :, 1:]  # Slicing out root
        r_rot_quat, r_pos = motion_process.recover_root_rot_pos(target.permute(0, 2, 3, 1).type(th.FloatTensor))
        print(f'r_rot_quat: {r_rot_quat.shape}')
        print(f'calc_vel_from_xyz: {calc_vel_from_xyz.shape}')
        calc_vel_from_xyz = calc_vel_from_xyz.permute(0, 2, 3, 1)
        calc_vel_from_xyz = calc_vel_from_xyz.reshape((1, 1, -1, 21, 3)).type(th.FloatTensor)
        r_rot_quat_adapted = r_rot_quat[..., :-1, None, :].repeat((1, 1, 1, 21, 1)).to(calc_vel_from_xyz.device)
        print(f'calc_vel_from_xyz: {calc_vel_from_xyz.shape} , {calc_vel_from_xyz.device}')
        print(f'r_rot_quat_adapted: {r_rot_quat_adapted.shape}, {r_rot_quat_adapted.device}')

        calc_vel_from_xyz = motion_process.qrot(r_rot_quat_adapted, calc_vel_from_xyz)
        calc_vel_from_xyz = calc_vel_from_xyz.reshape((1, 1, -1, 21 * 3))
        calc_vel_from_xyz = calc_vel_from_xyz.permute(0, 3, 1, 2)
        print(f'calc_vel_from_xyz: {calc_vel_from_xyz.shape} , {calc_vel_from_xyz.device}')

        import matplotlib.pyplot as plt
        for i in range(21):
            plt.plot(np.linalg.norm(calc_vel_from_xyz[:, i * 3:(i + 1) * 3, :, :].cpu().detach().numpy().reshape((3, -1)), axis=0), label='Calc Vel')
            plt.plot(np.linalg.norm(velocity_from_vector[:, i * 3:(i + 1) * 3, :, :].cpu().detach().numpy().reshape((3, -1)), axis=0), label='Vector Vel')
            plt.title(f'Joint idx: {i}')
            plt.legend()
            plt.show()
        print(calc_vel_from_xyz.shape)
        print(velocity_from_vector.shape)
        diff = calc_vel_from_xyz - velocity_from_vector
        print(np.linalg.norm(diff.cpu().detach().numpy().reshape((63, -1)), axis=0))
        return 0

    def _prior_bpd(self, x_start):
        """
        Get the prior KL term for the variational lower-bound, measured in
        bits-per-dim.

        This term can't be optimized, as it only depends on the encoder.

        :param x_start: the [N x C x ...] tensor of inputs.
        :return: a batch of [N] KL values (in bits), one per batch element.
        """
        batch_size = x_start.shape[0]
        t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
        qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
        kl_prior = normal_kl(
            mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
        )
        return mean_flat(kl_prior) / np.log(2.0)

    def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
        """
        Compute the entire variational lower-bound, measured in bits-per-dim,
        as well as other related quantities.

        :param model: the model to evaluate loss on.
        :param x_start: the [N x C x ...] tensor of inputs.
        :param clip_denoised: if True, clip denoised samples.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.

        :return: a dict containing the following keys:
                 - total_bpd: the total variational lower-bound, per batch element.
                 - prior_bpd: the prior term in the lower-bound.
                 - vb: an [N x T] tensor of terms in the lower-bound.
                 - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
                 - mse: an [N x T] tensor of epsilon MSEs for each timestep.
        """
        device = x_start.device
        batch_size = x_start.shape[0]

        vb = []
        xstart_mse = []
        mse = []
        for t in list(range(self.num_timesteps))[::-1]:
            t_batch = th.tensor([t] * batch_size, device=device)
            noise = th.randn_like(x_start)
            x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
            # Calculate VLB term at the current timestep
            with th.no_grad():
                out = self._vb_terms_bpd(
                    model,
                    x_start=x_start,
                    x_t=x_t,
                    t=t_batch,
                    clip_denoised=clip_denoised,
                    model_kwargs=model_kwargs,
                )
            vb.append(out["output"])
            xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
            eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
            mse.append(mean_flat((eps - noise) ** 2))

        vb = th.stack(vb, dim=1)
        xstart_mse = th.stack(xstart_mse, dim=1)
        mse = th.stack(mse, dim=1)

        prior_bpd = self._prior_bpd(x_start)
        total_bpd = vb.sum(dim=1) + prior_bpd
        return {
            "total_bpd": total_bpd,
            "prior_bpd": prior_bpd,
            "vb": vb,
            "xstart_mse": xstart_mse,
            "mse": mse,
        }


# not used now #
class GaussianDiffusionV6:
    """
    Utilities for training and sampling diffusion models.

    Ported directly from here, and then adapted over time to further experimentation.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42

    :param betas: a 1-D numpy array of betas for each diffusion timestep,
                  starting at T and going to 1.
    :param model_mean_type: a ModelMeanType determining what the model outputs.
    :param model_var_type: a ModelVarType determining how variance is output.
    :param loss_type: a LossType determining the loss function to use.
    :param rescale_timesteps: if True, pass floating point timesteps into the
                              model so that they are always scaled like in the
                              original paper (0 to 1000).
    """

    def __init__(
        self,
        *,
        betas,
        model_mean_type,
        model_var_type,
        loss_type,
        rescale_timesteps=False,
        lambda_rcxyz=0.,
        lambda_vel=0.,
        lambda_pose=1.,
        lambda_orient=1.,
        lambda_loc=1.,
        data_rep='rot6d',
        lambda_root_vel=0.,
        lambda_vel_rcxyz=0.,
        lambda_fc=0.,
        denoising_stra="rep",
        inter_optim=False,
        args=None,
    ):
        self.model_mean_type = model_mean_type
        self.model_var_type = model_var_type  ## model var type ##
        self.loss_type = loss_type
        self.rescale_timesteps = rescale_timesteps
        self.data_rep = data_rep
        self.args = args  # possibly None

        ### GET the diff. suite ###
        self.diff_jts = self.args.diff_jts
        self.diff_basejtsrel = self.args.diff_basejtsrel
        self.diff_basejtse = self.args.diff_basejtse
        ### GET the diff. suite ###

        if data_rep != 'rot_vel' and lambda_pose != 1.:
            raise ValueError('lambda_pose is relevant only when training on velocities!')
        self.lambda_pose = lambda_pose
        self.lambda_orient = lambda_orient
        self.lambda_loc = lambda_loc

        self.lambda_rcxyz = lambda_rcxyz
        self.lambda_vel = lambda_vel
        self.lambda_root_vel = lambda_root_vel
        self.lambda_vel_rcxyz = lambda_vel_rcxyz
        self.lambda_fc = lambda_fc

        ### === denoising_stra for the denoising process === ###
        self.denoising_stra = denoising_stra
        self.inter_optim = inter_optim

        if self.lambda_rcxyz > 0. or self.lambda_vel > 0. or self.lambda_root_vel > 0. or \
                self.lambda_vel_rcxyz > 0. or self.lambda_fc > 0.:
            assert self.loss_type == LossType.MSE, 'Geometric losses are supported by MSE loss type only!'

        self.var_sched = VarianceSchedule(len(betas), torch.tensor(betas, dtype=torch.float64))

        # Use float64 for accuracy.
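        # Standard DDPM bookkeeping below: alpha_t = 1 - beta_t and
        # alpha_bar_t = prod_{s <= t} alpha_s. The derived square-root terms
        # feed q_sample (x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps)
        # and the Gaussian posterior q(x_{t-1} | x_t, x_0).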
        betas = np.array(betas, dtype=np.float64)
        self.betas = betas
        assert len(betas.shape) == 1, "betas must be 1-D"
        assert (betas > 0).all() and (betas <= 1).all()

        self.num_timesteps = int(betas.shape[0])

        alphas = 1.0 - betas
        self.alphas_cumprod = np.cumprod(alphas, axis=0)  # alpha_bars #
        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
        self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
        assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = (
            betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        # log calculation clipped because the posterior variance is 0 at the
        # beginning of the diffusion chain.
        self.posterior_log_variance_clipped = np.log(
            np.append(self.posterior_variance[1], self.posterior_variance[1:])
        )
        self.posterior_mean_coef1 = (
            betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        self.posterior_mean_coef2 = (
            (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
        )

        self.l2_loss = lambda a, b: (a - b) ** 2  # th.nn.MSELoss(reduction='none')  # must be None for handling mask later on.

        ''' Load statistics '''
        avg_joints_motion_ours_fn = f"/home/xueyi/sim/motion-diffusion-model/avg_joints_motion_ours_nb_{700}_nth_{0.005}.npy"
        std_joints_motion_ours_fn = f"/home/xueyi/sim/motion-diffusion-model/std_joints_motion_ours_nb_{700}_nth_{0.005}.npy"
        avg_joints_motion_dists_ours_fn = f"/home/xueyi/sim/motion-diffusion-model/avg_joints_dist_motion_ours_nb_{700}_nth_{0.005}.npy"
        std_joints_motion_dists_ours_fn = f"/home/xueyi/sim/motion-diffusion-model/std_joints_dist_motion_ours_nb_{700}_nth_{0.005}.npy"
        avg_joints_rel = np.load(avg_joints_motion_ours_fn, allow_pickle=True)
        std_joints_rel = np.load(std_joints_motion_ours_fn, allow_pickle=True)
        avg_joints_dists = np.load(avg_joints_motion_dists_ours_fn, allow_pickle=True)
        std_joints_dists = np.load(std_joints_motion_dists_ours_fn, allow_pickle=True)
        ## self.avg_joints_rel, self.std_joints_rel; self.avg_joints_dists, self.std_joints_dists ##
        self.avg_joints_rel = torch.from_numpy(avg_joints_rel).float()
        self.std_joints_rel = torch.from_numpy(std_joints_rel).float()
        self.avg_joints_dists = torch.from_numpy(avg_joints_dists).float()
        self.std_joints_dists = torch.from_numpy(std_joints_dists).float()
        ''' Load statistics '''

        ''' Load avg, std statistics '''
        # self.maxx_rel, minn_rel, maxx_dists, minn_dists #
        rel_dists_stats_fn = "/home/xueyi/sim/motion-diffusion-model/base_pts_rel_dists_stats.npy"
        rel_dists_stats = np.load(rel_dists_stats_fn, allow_pickle=True).item()
        maxx_rel = rel_dists_stats['maxx_rel']
        minn_rel = rel_dists_stats['minn_rel']
        maxx_dists = rel_dists_stats['maxx_dists']
        minn_dists = rel_dists_stats['minn_dists']
        self.maxx_rel = torch.from_numpy(maxx_rel).float()
        self.minn_rel = torch.from_numpy(minn_rel).float()
        self.maxx_dists = torch.from_numpy(maxx_dists).float()
        self.minn_dists = torch.from_numpy(minn_dists).float()
        ''' Load avg, std statistics '''
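        # Note: the statistics loaded above and below come from machine-specific
        # absolute paths; they appear to provide the dataset mean/std (and
        # min/max) used to normalize relative joint positions and distances.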
"/home/xueyi/sim/motion-diffusion-model/avg_joints_motion_ours.npy" std_jts_fn = "/home/xueyi/sim/motion-diffusion-model/std_joints_motion_ours.npy" avg_jts = np.load(avg_jts_fn, allow_pickle=True) std_jts = np.load(std_jts_fn, allow_pickle=True) # self.avg_jts, self.std_jts # self.avg_jts = torch.from_numpy(avg_jts).float() self.std_jts = torch.from_numpy(std_jts).float() ''' Load avg-jts, std-jts ''' def masked_l2(self, a, b, mask): # assuming a.shape == b.shape == bs, J, Jdim, seqlen # assuming mask.shape == bs, 1, 1, seqlen loss = self.l2_loss(a, b) loss = sum_flat(loss * mask.float()) # gives \sigma_euclidean over unmasked elements n_entries = a.shape[1] * a.shape[2] non_zero_elements = sum_flat(mask) * n_entries ## non-zero-elements # print('mask', mask.shape) # print('non_zero_elements', non_zero_elements) # print('loss', loss) mse_loss_val = loss / non_zero_elements # print('mse_loss_val', mse_loss_val) return mse_loss_val def q_mean_variance(self, x_start, t): """ Get the distribution q(x_t | x_0). # q-mean-variance # :param x_start: the [N x C x ...] tensor of noiseless inputs. :param t: the number of diffusion steps (minus 1). Here, 0 means one step. :return: A tuple (mean, variance, log_variance), all of x_start's shape. """ mean = ( _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start ) variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) log_variance = _extract_into_tensor( self.log_one_minus_alphas_cumprod, t, x_start.shape ) return mean, variance, log_variance def q_sample(self, x_start, t, noise=None): """ Diffuse the dataset for a given number of diffusion steps. In other words, sample from q(x_t | x_0). ## q pos :param x_start: the initial dataset batch. :param t: the number of diffusion steps (minus 1). Here, 0 means one step. :param noise: if specified, the split-out normal noise. :return: A noisy version of x_start. """ if noise is None: noise = th.randn_like(x_start) assert noise.shape == x_start.shape return ( _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise ### variance xxx noise ### ) ## q_sample, def q_posterior_mean_variance(self, x_start, x_t, t): ## q_posterior """ Compute the mean and variance of the diffusion posterior: q(x_{t-1} | x_t, x_0) """ assert x_start.shape == x_t.shape posterior_mean = ( # posterior mean and variance # _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t ) posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) posterior_log_variance_clipped = _extract_into_tensor( self.posterior_log_variance_clipped, t, x_t.shape ) assert ( posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] == x_start.shape[0] ) return posterior_mean, posterior_variance, posterior_log_variance_clipped def p_mean_variance( ## get mean data ## self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None ): """ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of the initial x, x_0. :param model: the model, which takes a signal and a batch of timesteps as input. :param x: the [N x C x ...] tensor at time t. :param t: a 1-D Tensor of timesteps. :param clip_denoised: if True, clip the denoised signal into [-1, 1]. 
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample. Applies before
            clip_denoised.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :return: a dict with the following keys:
                 - 'mean': the model mean output.
                 - 'variance': the model variance output.
                 - 'log_variance': the log of 'variance'.
                 - 'pred_xstart': the prediction for x_0.
        """
        if model_kwargs is None:
            model_kwargs = {}

        ## === version 1 -> predict x_start in the rel domain === ##
        ### == x -> formulated as model inputs == ###
        ### == TODO: process x_start from the predicted value (relative positions, signed distances) with base_pts and base_normals == ###
        B = x['input_data']['base_pts'].shape[0]
        assert t.shape == (B,)
        input_data = x['input_data']

        ## decode latents to joints at timestep t ##
        out_dict = model.model.dec_latents_to_joints_with_t(x, input_data, self._scale_timesteps(t).clone())
        # out_dict = model(input_data, self._scale_timesteps(t).clone())

        rt_dict = {}
        ### === model variance and log_variance === ###
        model_variance, model_log_variance = self.posterior_variance, self.posterior_log_variance_clipped

        if self.diff_jts:
            ## x_t: perturbed joints seq latents ##
            pert_joints_seq_latents = x['joints_seq_latents']  # x_t
            pred_clean_joint_seq_latents = out_dict["joints_seq_latents"]
            if self.args.pred_diff_noise:
                ## eps -> estimated noise; recover x_start from the predicted eps ##
                pred_clean_joint_seq_latents = self._predict_xstart_from_eps(
                    pert_joints_seq_latents.permute(1, 0, 2), t=t,
                    eps=pred_clean_joint_seq_latents.permute(1, 0, 2)
                ).permute(1, 0, 2)  # seq x bs x d
                # minn_pert_joints_seq_latents, _ = torch.min(pred_clean_joint_seq_latents.view(pred_clean_joint_seq_latents.size(0) * pred_clean_joint_seq_latents.size(1), -1).contiguous(), dim=0)
                # maxx_pert_joints_seq_latents, _ = torch.max(pred_clean_joint_seq_latents.view(pred_clean_joint_seq_latents.size(0) * pred_clean_joint_seq_latents.size(1), -1).contiguous(), dim=0)
                # print(f"pred minn_pert_joints_latents: {minn_pert_joints_seq_latents[:10]}, pred maxx_pert_joints_seq_latents: {maxx_pert_joints_seq_latents[:10]}")
            ## use the predicted x_start to decode the joint outputs ##
            out_dict["joint_seq_output"] = model.model.dec_jts_only_fr_latents(pred_clean_joint_seq_latents)["joint_seq_output"]
            # pred_clean_joint_seq_latents = pert_joints_seq_latents

            joints_seq_latents_mean, _, _ = self.q_posterior_mean_variance(
                x_start=pred_clean_joint_seq_latents.permute(1, 0, 2),
                x_t=pert_joints_seq_latents.permute(1, 0, 2), t=t
            )
            joints_seq_latents_mean = joints_seq_latents_mean.permute(1, 0, 2)
            joints_seq_latents_variance = _extract_into_tensor(model_variance, t, joints_seq_latents_mean.permute(1, 0, 2).shape).permute(1, 0, 2)
            joints_seq_latents_log_variance = _extract_into_tensor(model_log_variance, t, joints_seq_latents_mean.permute(1, 0, 2).shape).permute(1, 0, 2)

            joint_seq_output = out_dict["joint_seq_output"]

            jts_seq_rt_dict = {
                ### joints seq latents ###
                "joints_seq_latents_mean": joints_seq_latents_mean,
                "joints_seq_latents_variance": joints_seq_latents_variance,
                "joints_seq_latents_log_variance": joints_seq_latents_log_variance,
                ### decoded output values ###
                "joint_seq_output": joint_seq_output,
            }
        else:
            jts_seq_rt_dict = {}

        if self.diff_basejtsrel:
            if 'basejtsrel_output' in out_dict:
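                # Branch 1 (descriptive note): the decoder returned data-space
                # outputs directly (relative offsets plus average joints) rather
                # than latent codes.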
                pert_rel_base_pts_outputs = x['pert_rel_base_pts_to_rhand_joints']  # perturbed rel-base-pts outputs #
                pert_avg_joints_sequence = x['pert_avg_joints_sequence']

                basejtsrel_output = out_dict['basejtsrel_output'].transpose(-2, -3).contiguous()
                avg_jts_outputs = out_dict['avg_jts_outputs']
                # if pert_rel_base_pts_outputs.size(0) == 1:
                #     pert_rel_base_pts_outputs = pert_rel_base_pts_outputs.repeat(pred_basejtsrel_seq_latents.size(0), 1, 1)

                if self.args.pred_diff_noise:
                    ## eps -> estimated noise; recover x_start ##
                    basejtsrel_output = self._predict_xstart_from_eps(pert_rel_base_pts_outputs, t=t, eps=basejtsrel_output)
                    avg_jts_outputs = self._predict_xstart_from_eps(pert_avg_joints_sequence, t=t, eps=avg_jts_outputs)

                # out_dict.update(
                #     model.model.dec_basejtsrel_only_fr_latents(pred_basejtsrel_seq_latents, x['input_data'])
                # )

                ## use the predicted and perturbed values for the posterior mean ##
                basejtsrel_output_mean, _, _ = self.q_posterior_mean_variance(
                    x_start=basejtsrel_output, x_t=pert_rel_base_pts_outputs, t=t
                )
                avg_jts_outputs_mean, _, _ = self.q_posterior_mean_variance(
                    x_start=avg_jts_outputs, x_t=pert_avg_joints_sequence, t=t
                )

                ## from model_variance to the output variances ##
                basejtsrel_output_variance = _extract_into_tensor(model_variance, t, basejtsrel_output_mean.shape)
                basejtsrel_output_log_variance = _extract_into_tensor(model_log_variance, t, basejtsrel_output_mean.shape)
                avg_jts_outputs_variance = _extract_into_tensor(model_variance, t, avg_jts_outputs_mean.shape)
                avg_jts_outputs_log_variance = _extract_into_tensor(model_log_variance, t, avg_jts_outputs_mean.shape)
            else:
                # Branch 2: the decoder works in the latent space; x_t is the
                # perturbed latent sequence.
                pert_rel_base_pts_outputs = x['rel_base_pts_latents'].permute(1, 0, 2)  # rel base pts latents #
                basejtsrel_output = out_dict['joints_denoised_latents'].permute(1, 0, 2)
                if self.args.pred_diff_noise:  ## eps -> estimated noise ##
                    if self.args.use_var_sched:
                        print(f"Using var_sched... in p_mean_variance")
                        bsz = basejtsrel_output.size(0)
                        t_item = t[0].item()
                        alpha = self.var_sched.alphas[t_item]
                        alpha_bar = self.var_sched.alpha_bars[t_item]
                        sigma = self.var_sched.get_sigmas(t_item, 0.)
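                        # One ancestral DDPM step (standard eps-parameterized update):
                        #   x_{t-1} = (1/sqrt(alpha_t)) * (x_t - (1 - alpha_t)/sqrt(1 - alpha_bar_t) * eps_theta) + sigma_t * z,
                        # with z ~ N(0, I) for t > 0 and z = 0 at t = 0; c0 and c1
                        # below are exactly those two coefficients.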
                        c0 = 1.0 / torch.sqrt(alpha)
                        c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
                        # x_t = traj[t]
                        # beta = self.var_sched.betas[[t[0].item()] * bsz]
                        # if mask is not None:
                        #     x_t = x_t * mask
                        # e_theta = self.net(x_t, beta=beta, context=context)
                        z = torch.randn_like(basejtsrel_output) if t_item > 0 else torch.zeros_like(basejtsrel_output)
                        basejtsrel_output = c0 * (pert_rel_base_pts_outputs - c1 * basejtsrel_output) + sigma * z
                    else:
                        basejtsrel_output = self._predict_xstart_from_eps(pert_rel_base_pts_outputs, t=t, eps=basejtsrel_output)
                # basejtsrel_output = pert_rel_base_pts_outputs

                pred_joints = model.model.dec_basejtsrel_only_fr_latents(basejtsrel_output.permute(1, 0, 2), input_data)['basejtsrel_output']

                ## use the predicted and perturbed latents for the posterior mean ##
                basejtsrel_output_mean, _, _ = self.q_posterior_mean_variance(
                    x_start=basejtsrel_output, x_t=pert_rel_base_pts_outputs, t=t
                )
                ## from model_variance to the output variances ##
                basejtsrel_output_variance = _extract_into_tensor(model_variance, t, basejtsrel_output_mean.shape)
                basejtsrel_output_log_variance = _extract_into_tensor(model_log_variance, t, basejtsrel_output_mean.shape)

            # Commented-out path for avg_jts_outputs denoising:
            # if 'avg_jts_outputs' in out_dict:
            #     pert_avg_joints_sequence = x['pert_avg_joints_sequence']
            #     avg_jts_outputs = out_dict['avg_jts_outputs']
            # else:
            #     pert_avg_joints_sequence = x['input_data']['pert_avg_joints_sequence']
            #     avg_jts_outputs = pert_avg_joints_sequence
            # if self.args.pred_diff_noise:
            #     avg_jts_outputs = self._predict_xstart_from_eps(pert_avg_joints_sequence, t=t, eps=avg_jts_outputs)
            # avg_jts_outputs_mean, _, _ = self.q_posterior_mean_variance(
            #     x_start=avg_jts_outputs, x_t=pert_avg_joints_sequence, t=t
            # )
            # avg_jts_outputs_variance = _extract_into_tensor(model_variance, t, avg_jts_outputs_mean.shape)
            # avg_jts_outputs_log_variance = _extract_into_tensor(model_log_variance, t, avg_jts_outputs_mean.shape)

            basejtsrel_seq_rt_dict = {
                ### basejtsrel seq latents ###
                # "avg_jts_outputs": avg_jts_outputs,
                "basejtsrel_output_variance": basejtsrel_output_variance,
                "basejtsrel_output_log_variance": basejtsrel_output_log_variance,
                # "avg_jts_outputs_variance": avg_jts_outputs_variance,
                # "avg_jts_outputs_log_variance": avg_jts_outputs_log_variance,
                "basejtsrel_output": basejtsrel_output_mean,
                "basejtsrel_output_ori": basejtsrel_output,
                'pred_joints': pred_joints,
            }
            # if 'avg_jts_outputs' in out_dict:
            #     basejtsrel_seq_rt_dict['avg_jts_outputs'] = out_dict['avg_jts_outputs']
        else:
            basejtsrel_seq_rt_dict = {}

        if self.diff_basejtse:
            base_jts_e_feats = x['base_jts_e_feats']  ### x_t values here ###
            pred_basejtse_seq_latents = out_dict['base_jts_e_feats']

            ### q-sampled latent mean here ###
            basejtse_seq_latents_mean, _, _ = self.q_posterior_mean_variance(
                x_start=pred_basejtse_seq_latents.permute(1, 0, 2),
                x_t=base_jts_e_feats.permute(1, 0, 2), t=t
            )
            basejtse_seq_latents_mean = basejtse_seq_latents_mean.permute(1, 0, 2)
            basejtse_seq_latents_variance = _extract_into_tensor(model_variance, t, basejtse_seq_latents_mean.permute(1, 0, 2).shape).permute(1, 0, 2)
            basejtse_seq_latents_log_variance = _extract_into_tensor(model_log_variance, t, basejtse_seq_latents_mean.permute(1, 0, 2).shape).permute(1, 0, 2)

            # base_jts_e_feats = out_dict["base_jts_e_feats"]
            dec_e_along_normals = out_dict["dec_e_along_normals"]
            dec_e_vt_normals = out_dict["dec_e_vt_normals"]
            dec_d = out_dict['dec_d']
            rel_vel_dec = out_dict['rel_vel_dec']
            basejtse_seq_rt_dict = {
                ### basejtse seq latents ###
                "basejtse_seq_latents_mean": basejtse_seq_latents_mean,
                "basejtse_seq_latents_variance": basejtse_seq_latents_variance,
                "basejtse_seq_latents_log_variance": basejtse_seq_latents_log_variance,
                "dec_e_along_normals": dec_e_along_normals,
                "dec_e_vt_normals": dec_e_vt_normals,
                "dec_d": dec_d,
                "rel_vel_dec": rel_vel_dec,
            }
        else:
            basejtse_seq_rt_dict = {}

        rt_dict.update(jts_seq_rt_dict)
        rt_dict.update(basejtsrel_seq_rt_dict)
        rt_dict.update(basejtse_seq_rt_dict)

        return rt_dict

    def _predict_xstart_from_eps(self, x_t, t, eps):
        assert x_t.shape == eps.shape
        return (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
        )

    def _predict_xstart_from_xprev(self, x_t, t, xprev):
        assert x_t.shape == xprev.shape
        return (  # (xprev - coef2 * x_t) / coef1
            _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
            - _extract_into_tensor(
                self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
            ) * x_t
        )

    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
        return (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - pred_xstart
        ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    def _scale_timesteps(self, t):
        if self.rescale_timesteps:
            return t.float() * (1000.0 / self.num_timesteps)
        return t

    def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute the mean for the previous step, given a function cond_fn that
        computes the gradient of a conditional log probability with respect to
        x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
        condition on y.

        This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
        """
        gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
        new_mean = (
            p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
        )
        return new_mean

    def condition_mean_with_grad(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute the mean for the previous step, given a function cond_fn that
        computes the gradient of a conditional log probability with respect to
        x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
        condition on y.

        This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
        """
        gradient = cond_fn(x, t, p_mean_var, **model_kwargs)
        new_mean = (
            p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
        )
        return new_mean

    def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute what the p_mean_variance output would have been, should the
        model's score function be conditioned by cond_fn.

        See condition_mean() for details on cond_fn.

        Unlike condition_mean(), this instead uses the conditioning strategy
        from Song et al (2020).
        """
        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)

        eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
        eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
            x, self._scale_timesteps(t), **model_kwargs
        )

        out = p_mean_var.copy()
        out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
        out["mean"], _, _ = self.q_posterior_mean_variance(
            x_start=out["pred_xstart"], x_t=x, t=t
        )
        return out

    def condition_score_with_grad(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute what the p_mean_variance output would have been, should the
        model's score function be conditioned by cond_fn.

        See condition_mean() for details on cond_fn.
        Unlike condition_mean(), this instead uses the conditioning strategy
        from Song et al (2020).
        """
        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)

        eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
        eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
            x, t, p_mean_var, **model_kwargs
        )

        out = p_mean_var.copy()
        out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
        out["mean"], _, _ = self.q_posterior_mean_variance(
            x_start=out["pred_xstart"], x_t=x, t=t
        )
        return out

    def p_sample(  ## p_sample ##
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        const_noise=False,
    ):
        """
        Sample x_{t-1} from the model at the given timestep.

        :param model: the model to sample from.
        :param x: the current tensor at x_{t-1}.
        :param t: the value of t, starting at 0 for the first diffusion step.
        :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample.
        :param cond_fn: if not None, this is a gradient function that acts
                        similarly to the model.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :return: a dict containing the following keys:
                 - 'sample': a random sample from the model.
                 - 'pred_xstart': a prediction of x_0.
        """
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )

        rt_dict = {}
        if self.diff_jts:
            # joints_seq_latents: seq x bsz x latent_dim #
            joints_seq_latents_noise = th.randn_like(x['joints_seq_latents'])
            # print('const_noise', const_noise)
            # if const_noise:
            #     print(f"joints latents shape, ", x['joints_seq_latents'].shape)
            #     joints_seq_latents_noise = joints_seq_latents_noise[[0]].repeat(x['joints_seq_latents'].shape[0], 1, 1)
            joints_seq_latents_nonzero_mask = (
                (t != 0).float().view(-1, *([1] * (len(x['joints_seq_latents'].shape) - 1)))
            )  # no noise when t == 0; bsz x nseq x dim #
            #### ==== joints_seq_latents ===== ####
            # Note: with const_noise, t maps to the same noise across the
            # sequence, which cannot depict the latent space very well.
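            # Reparameterized draw from the Gaussian posterior (standard form):
            #   sample = mean + 1[t > 0] * exp(0.5 * log_variance) * noise.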
            joints_seq_latents_sample = out["joints_seq_latents_mean"].permute(1, 0, 2) \
                + joints_seq_latents_nonzero_mask * th.exp(0.5 * out["joints_seq_latents_log_variance"].permute(1, 0, 2)) * joints_seq_latents_noise.permute(1, 0, 2)  # nseq x bsz x dim ## joints sample ##
            joints_seq_latents_sample = joints_seq_latents_sample.permute(1, 0, 2)
            #### ==== joints_seq_latents ===== ####
            joint_seq_output = out["joint_seq_output"]

            jts_seq_rt_dict = {
                "joints_seq_latents_sample": joints_seq_latents_sample,
                "joint_seq_output": joint_seq_output,
            }
        else:
            jts_seq_rt_dict = {}

        if self.diff_basejtsrel:
            ##### ===== Sample for basejtsrel_seq_latents_sample ===== #####
            ### rel_base_pts_outputs mask ###
            basejtsrel_seq_latents_noise = th.randn_like(out['basejtsrel_output'])
            if const_noise:  ## use the same noise for the whole batch ##
                basejtsrel_seq_latents_noise = basejtsrel_seq_latents_noise[[0]].repeat(out['basejtsrel_output'].shape[0], 1, 1, 1, 1)
            basejtsrel_seq_latents_nonzero_mask = (
                (t != 0).float().view(-1, *([1] * (len(out['basejtsrel_output'].shape) - 1)))
            )  # no noise when t == 0
            #### ==== basejtsrel_seq_latents ===== ####
            ## sample latent codes -> denoise latent codes ##
            basejtsrel_seq_latents_sample = out["basejtsrel_output"] \
                + basejtsrel_seq_latents_nonzero_mask * th.exp(0.5 * out["basejtsrel_output_log_variance"]) * basejtsrel_seq_latents_noise
            # basejtsrel_seq_latents_sample = basejtsrel_seq_latents_sample.permute(1, 0, 2)
            #### ==== basejtsrel_seq_latents ===== ####
            ##### ===== Sample for basejtsrel_seq_latents_sample ===== #####
            if self.args.pred_diff_noise and self.args.use_var_sched:
                print(f"Using var_sched... in p_sample")
                # the ancestral step was already taken in p_mean_variance #
                basejtsrel_seq_latents_sample = out['basejtsrel_output_ori']

            # Commented-out sampling for avg_jts_outputs:
            # avg_jts_outputs_noise = th.randn_like(x['avg_joints_sequence'])
            # if const_noise:
            #     avg_jts_outputs_noise = basejtsrel_seq_latents_noise[[0]].repeat(x['avg_joints_sequence'].shape[0], 1, 1, 1, 1)
            # avg_jts_outputs_nonzero_mask = (
            #     (t != 0).float().view(-1, *([1] * (len(x['avg_joints_sequence'].shape) - 1)))
            # )  # no noise when t == 0
            # avg_jts_outputs_sample = out["avg_jts_outputs"] + avg_jts_outputs_nonzero_mask * th.exp(0.5 * out["avg_jts_outputs_log_variance"]) * avg_jts_outputs_noise

            basejtsrel_rt_dict = {
                "basejtsrel_seq_latents_sample": basejtsrel_seq_latents_sample.permute(1, 0, 2),
                # "avg_jts_outputs_sample": avg_jts_outputs_sample,
                'pred_joints': out['pred_joints'],
            }
            # if 'avg_jts_outputs' in out:
            #     basejtsrel_rt_dict['avg_jts_outputs'] = out['avg_jts_outputs']
        else:
            basejtsrel_rt_dict = {}  ### basejtsrel rt dict ###

        if self.diff_basejtse:
            ##### ===== Sample for basejtse_seq_latents_sample ===== #####
            ### rel_base_pts_outputs mask ###
            basejtse_seq_latents_noise = th.randn_like(x['base_jts_e_feats'])
            # print('const_noise', const_noise)
            if const_noise:
                basejtse_seq_latents_noise = basejtse_seq_latents_noise[[0]].repeat(x['base_jts_e_feats'].shape[0], 1, 1, 1, 1)
            basejtse_seq_latents_nonzero_mask = (
                (t != 0).float().view(-1, *([1] * (len(x['base_jts_e_feats'].shape) - 1)))
            )  # no noise when t == 0
            #### ==== basejtse_seq_latents ===== ####
            basejtse_seq_latents_sample = out["basejtse_seq_latents_mean"].permute(1, 0, 2) \
                + basejtse_seq_latents_nonzero_mask * th.exp(0.5 * out["basejtse_seq_latents_log_variance"].permute(1, 0, 2)) * basejtse_seq_latents_noise.permute(1, 0, 2)
            basejtse_seq_latents_sample = basejtse_seq_latents_sample.permute(1, 0, 2)
            #### ==== basejtse_seq_latents ===== ####
            ##### ===== Sample for basejtse_seq_latents_sample ===== #####

            dec_e_along_normals = out["dec_e_along_normals"]
            dec_e_vt_normals = out["dec_e_vt_normals"]
            rel_vel_dec = out["rel_vel_dec"]
            dec_d = out["dec_d"]

            basejtse_rt_dict = {
                "basejtse_seq_latents_sample": basejtse_seq_latents_sample,
                "dec_e_along_normals": dec_e_along_normals,
                "dec_e_vt_normals": dec_e_vt_normals,
                "rel_vel_dec": rel_vel_dec,
                "dec_d": dec_d,
            }
        else:
            basejtse_rt_dict = {}

        rt_dict.update(jts_seq_rt_dict)
        rt_dict.update(basejtsrel_rt_dict)
        rt_dict.update(basejtse_rt_dict)

        return rt_dict

    def p_sample_with_grad(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
    ):
        """
        Sample x_{t-1} from the model at the given timestep.

        :param model: the model to sample from.
        :param x: the current tensor at x_{t-1}.
        :param t: the value of t, starting at 0 for the first diffusion step.
        :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample.
        :param cond_fn: if not None, this is a gradient function that acts
                        similarly to the model.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :return: a dict containing the following keys:
                 - 'sample': a random sample from the model.
                 - 'pred_xstart': a prediction of x_0.
        """
        with th.enable_grad():
            x = x.detach().requires_grad_()
            out = self.p_mean_variance(
                model,
                x,
                t,
                clip_denoised=clip_denoised,
                denoised_fn=denoised_fn,
                model_kwargs=model_kwargs,
            )
            noise = th.randn_like(x)
            nonzero_mask = (
                (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
            )  # no noise when t == 0
            if cond_fn is not None:
                out["mean"] = self.condition_mean_with_grad(
                    cond_fn, out, x, t, model_kwargs=model_kwargs
                )
            sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
        return {"sample": sample, "pred_xstart": out["pred_xstart"].detach()}

    def p_sample_loop(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        skip_timesteps=0,
        init_image=None,
        randomize_class=False,
        cond_fn_with_grad=False,
        dump_steps=None,
        const_noise=False,
        st_timestep=None,
    ):
        """
        Generate samples from the model.

        :param model: the model module.
        :param shape: the shape of the samples, (N, C, H, W).
        :param noise: if specified, the noise from the encoder to sample.
                      Should be of the same shape as `shape`.
        :param clip_denoised: if True, clip x_start predictions to [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample.
        :param cond_fn: if not None, this is a gradient function that acts
                        similarly to the model.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param device: if specified, the device to create the samples on.
                       If not specified, use a model parameter's device.
        :param progress: if True, show a tqdm progress bar.
        :param const_noise: If True, will noise all samples with the same noise
                            throughout sampling.
        :return: a non-differentiable batch of samples.
        """
        final = None
        if dump_steps is not None:  ## dump steps ##
            dump = []
        for i, sample in enumerate(self.p_sample_loop_progressive(
            model,
            shape,
            noise=noise,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            cond_fn=cond_fn,
            model_kwargs=model_kwargs,
            device=device,
            progress=progress,
            skip_timesteps=skip_timesteps,
            init_image=init_image,
            randomize_class=randomize_class,
            cond_fn_with_grad=cond_fn_with_grad,
            const_noise=const_noise,  # the same noise per step #
            st_timestep=st_timestep,
        )):
            if dump_steps is not None and i in dump_steps:
                dump.append(deepcopy(sample))
            final = sample
        if dump_steps is not None:
            return dump
        return final

    def p_sample_loop_progressive(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        skip_timesteps=0,
        init_image=None,
        randomize_class=False,
        cond_fn_with_grad=False,
        const_noise=False,
        st_timestep=None,
    ):
        """
        Generate samples from the model and yield intermediate samples from
        each timestep of diffusion.

        Arguments are the same as p_sample_loop().
        Returns a generator over dicts, where each dict is the return value of
        p_sample().
        """
        ####### ==== a conditional sampling from init_image here!!! ==== #######
        ## === give joints shape here === ##
        ### ==== set the shape for sampling ==== ###
        ### === init_image should not be None === ###
        base_pts = init_image['base_pts']
        base_normals = init_image['base_normals']  ## base normals ##
        # rel_base_pts_to_rhand_joints = init_image['rel_base_pts_to_rhand_joints']
        # dist_base_pts_to_rhand_joints = init_image['dist_base_pts_to_rhand_joints']
        rhand_joints = init_image['rhand_joints']
        vel_obj_pts_to_hand_pts = init_image['vel_obj_pts_to_hand_pts']
        obj_pts_disp = init_image["obj_pts_disp"]
        # rhand_joints = init_image['gt_rhand_joints']

        # Center (and optionally scale) the hand joints per sequence.
        std_joints_sequence = torch.std(rhand_joints, dim=1)
        avg_joints_sequence = torch.mean(rhand_joints, dim=1)  # bsz x nnj x 3 ---> for rhand joints ##
        if self.args.joint_std_v2:
            std_joints_sequence = torch.std(rhand_joints.view(rhand_joints.size(0), -1, 3), dim=1).unsqueeze(1)
            avg_joints_sequence = torch.mean(rhand_joints.view(rhand_joints.size(0), -1, 3), dim=1).unsqueeze(1)
        elif self.args.joint_std_v3:
            # std_joints_sequence = torch.std(rhand_joints.view(rhand_joints.size(0), -1, 3), dim=1).unsqueeze(1)
            avg_joints_sequence = torch.mean(rhand_joints.view(rhand_joints.size(0), -1, 3), dim=1).unsqueeze(1)
            std_joints_sequence = torch.std((rhand_joints - avg_joints_sequence.unsqueeze(1)).view(rhand_joints.size(0), -1), dim=1).unsqueeze(1).unsqueeze(1)

        joints_offset_sequence = rhand_joints - avg_joints_sequence.unsqueeze(1)  # bsz x nf x nnj x 3 #
        joints_offset_sequence_ori = joints_offset_sequence.clone()
        # rhand_joints_ori = rhand_joints.clone()
        if self.args.jts_sclae_stra == "std":
            joints_offset_sequence = joints_offset_sequence / std_joints_sequence.unsqueeze(1)
        else:
            std_joints_sequence = torch.ones_like(std_joints_sequence)

        # if 'sampled_base_pts_nearest_obj_pc' in init_image:
        #     ambient_init_image = {
        #         'sampled_base_pts_nearest_obj_pc': init_image['sampled_base_pts_nearest_obj_pc'],
        #         'sampled_base_pts_nearest_obj_vns': init_image['sampled_base_pts_nearest_obj_vns'],
        #     }
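        # Descriptive note: disabling a normalization below swaps the stored
        # per-frame means for zeros and the stds for ones, so any later
        # (x - mean) / std step becomes a no-op.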
        # per_frame_avg_disp_along_normals, per_frame_std_disp_along_normals #
        # per_frame_avg_disp_vt_normals, per_frame_std_disp_vt_normals #
        if self.args.wo_e_normalization:
            init_image['per_frame_avg_disp_along_normals'] = torch.zeros_like(init_image['per_frame_avg_disp_along_normals'])
            init_image['per_frame_avg_disp_vt_normals'] = torch.zeros_like(init_image['per_frame_avg_disp_vt_normals'])
            init_image['per_frame_std_disp_along_normals'] = torch.ones_like(init_image['per_frame_std_disp_along_normals'])
            init_image['per_frame_std_disp_vt_normals'] = torch.ones_like(init_image['per_frame_std_disp_vt_normals'])

        if self.args.wo_rel_normalization:
            # per_frame_avg_joints_rel, per_frame_std_joints_rel #
            init_image['per_frame_avg_joints_rel'] = torch.zeros_like(init_image['per_frame_avg_joints_rel'])
            init_image['per_frame_std_joints_rel'] = torch.ones_like(init_image['per_frame_std_joints_rel'])

        init_image_avg_std_stats = {
            'rhand_joints': init_image['rhand_joints'],
            'per_frame_avg_joints_rel': init_image['per_frame_avg_joints_rel'],
            'per_frame_std_joints_rel': init_image['per_frame_std_joints_rel'],
            'per_frame_avg_joints_dists_rel': init_image['per_frame_avg_joints_dists_rel'],
            'per_frame_std_joints_dists_rel': init_image['per_frame_std_joints_dists_rel'],
        }

        if device is None:
            device = next(model.parameters()).device
        assert isinstance(shape, (tuple, list))
        # if noise is not None:
        #     img = noise
        # else:
        #     img = th.randn(*shape, device=device)
        # if skip_timesteps and init_image is None:
        #     rhand_joints = th.zeros_like(img)

        ## indices of the (possibly truncated) reverse process ##
        indices = list(range(self.num_timesteps - skip_timesteps))[::-1]

        if st_timestep is not None:  ## start from a later timestep ##
            indices = indices[-st_timestep:]
            print(f"st_timestep: {st_timestep}, indices: {indices}")

        joints_scaling_factor = 5.
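        # Relative representation, as computed below: each hand joint is
        # expressed with respect to every base point,
        #   rel[b, t, j, k] = joint[b, t, j] - base_pt[b, k],
        # then normalized per frame by the stored mean/std statistics.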
        # rel_base_pts_to_rhand_joints: bsz x ws x nnj x nnb x 3 #
        # base_pts: bsz x nnb x 3; base_normals: bsz x nnb x 3 #
        rel_base_pts_to_rhand_joints = rhand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1)  # bsz x nf x nnj x nnb x 3 #
        init_image['per_frame_avg_joints_rel'] = torch.mean(rel_base_pts_to_rhand_joints, dim=1, keepdim=True)
        # bsz x ws x nnj x nnb x 3 #
        rel_base_pts_to_rhand_joints = (rel_base_pts_to_rhand_joints - init_image['per_frame_avg_joints_rel']) / init_image['per_frame_std_joints_rel']

        if self.denoising_stra == "rep":
            ''' Normalization Strategy 4 '''
            my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0]
            normed_rhand_joints = (rhand_joints - self.avg_jts.unsqueeze(0).to(rhand_joints.device)) / self.std_jts.unsqueeze(0).to(rhand_joints.device)
            noise_rhand_joints = th.randn_like(normed_rhand_joints)
            pert_normed_rhand_joints = self.q_sample(normed_rhand_joints, my_t, noise=noise_rhand_joints)

            ### scale rhand joints; rhand_joints: bsz x ws x nnj x 3 ###
            exp_rhand_joints = rhand_joints.view(rhand_joints.size(0), rhand_joints.size(1) * rhand_joints.size(2), 3)
            if self.args.jts_sclae_stra == "std":
                avg_exp_rhand_joints = torch.mean(exp_rhand_joints, dim=1, keepdim=True)
                extents_rhand_joints = torch.std(exp_rhand_joints, dim=1, keepdim=True)
            elif self.args.jts_sclae_stra == "bbox":
                maxx_exp_rhand_joints, _ = torch.max(exp_rhand_joints, dim=1, keepdim=True)
                minn_exp_rhand_joints, _ = torch.min(exp_rhand_joints, dim=1, keepdim=True)
                avg_exp_rhand_joints = (maxx_exp_rhand_joints + minn_exp_rhand_joints) / 2.
                extents_rhand_joints = maxx_exp_rhand_joints - minn_exp_rhand_joints  ### bsz x 1 x 3 ###
                extents_rhand_joints = torch.sqrt(torch.sum(extents_rhand_joints ** 2, dim=-1, keepdim=True))
            else:
                raise ValueError(f"Unrecognized jts_sclae_stra: {self.args.jts_sclae_stra}")

            rhand_joints = (rhand_joints - avg_exp_rhand_joints.unsqueeze(1)) / extents_rhand_joints.unsqueeze(1)
            scaled_rhand_joints = rhand_joints * joints_scaling_factor
            noise_scaled_rhand_joints = th.randn_like(scaled_rhand_joints)
            pert_scaled_rhand_joints = self.q_sample(scaled_rhand_joints, my_t, noise=noise_scaled_rhand_joints)

            # pert_rhand_joints: bsz x nnj x 3 #
            pert_rhand_joints = pert_normed_rhand_joints * self.std_jts.unsqueeze(0).to(rhand_joints.device) + self.avg_jts.unsqueeze(0).to(rhand_joints.device)

            ### Calculate moving-related energies ###
            # denormed_rhand_joints: bsz x nf x nnj x 3 #
            ## denormed base pts to rhand joints -> denormed rel positions ##
            denormed_rel_base_pts_to_rhand_joints = rel_base_pts_to_rhand_joints * init_image['per_frame_std_joints_rel'] + init_image['per_frame_avg_joints_rel']
            denormed_rhand_joints = rhand_joints * extents_rhand_joints.unsqueeze(1) + avg_exp_rhand_joints.unsqueeze(1)
            denormed_dist_rel_base_pts_to_pert_rhand_joints = torch.sum(
                denormed_rel_base_pts_to_rhand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1
            )  ## signed distances to the base planes ##

            k_f = 1.
            ## l2 distances from base pts to pert rhand joints ##
            # l2_rel_base_pts_to_pert_rhand_joints: bsz x nf x nnj x nnb #
            l2_rel_base_pts_to_pert_rhand_joints = torch.norm(denormed_rel_base_pts_to_rhand_joints, dim=-1)
            ### att_forces ###
            att_forces = torch.exp(-k_f * l2_rel_base_pts_to_pert_rhand_joints)  # bsz x nf x nnj x nnb #
            # bsz x (ws - 1) x nnj x nnb #
            att_forces = att_forces[:, :-1, :, :]  # attraction forces, dropping the last frame #
            # rhand_joints: ws x nnj x 3 -> (ws - 1) x nnj x 3 #
            # bsz x (ws - 1) x nnj x 3 --> frame-to-frame displacements #
            denormed_rhand_joints_disp = denormed_rhand_joints[:, 1:, :, :] - denormed_rhand_joints[:, :-1, :, :]
            # signed distances along the base normals: bsz x (ws - 1) x nnj x nnb #
            signed_dist_base_pts_to_rhand_joints_along_normal = torch.sum(
                base_normals.unsqueeze(1).unsqueeze(1) * denormed_rhand_joints_disp.unsqueeze(-2), dim=-1
            )
            # rel_base_pts_to_rhand_joints_vt_normal: bsz x (ws - 1) x nnj x nnb x 3 -> the relative positions vertical to base normals #
            rel_base_pts_to_rhand_joints_vt_normal = denormed_rhand_joints_disp.unsqueeze(-2) - signed_dist_base_pts_to_rhand_joints_along_normal.unsqueeze(-1) * base_normals.unsqueeze(1).unsqueeze(1)
            dist_base_pts_to_rhand_joints_vt_normal = torch.sqrt(torch.sum(
                rel_base_pts_to_rhand_joints_vt_normal ** 2, dim=-1
            ))
            k_a = 1.
            k_b = 1.
            e_disp_rel_to_base_along_normals = k_a * att_forces * torch.abs(signed_dist_base_pts_to_rhand_joints_along_normal)
            # (ws - 1) x nnj x nnb # -> dist vt normals #
            e_disp_rel_to_baes_vt_normals = k_b * att_forces * dist_base_pts_to_rhand_joints_vt_normal
            # nf x nnj x nnb ---> dist_vt_normals -> nf x nnj x nnb #
            # e_disp_rel_to_base_along_normals, e_disp_rel_to_baes_vt_normals #
            # print("per frame avg disp along normals,", x_start['per_frame_avg_disp_along_normals'].size(), "per frame std along normals", x_start['per_frame_std_disp_along_normals'].size(), "e_disp_rel_to_base_along_normals", e_disp_rel_to_base_along_normals.size(), "e_disp_rel_to_baes_vt_normals", e_disp_rel_to_baes_vt_normals.size())
            e_disp_rel_to_base_along_normals = (e_disp_rel_to_base_along_normals - init_image['per_frame_avg_disp_along_normals']) / init_image['per_frame_std_disp_along_normals']
            e_disp_rel_to_baes_vt_normals = (e_disp_rel_to_baes_vt_normals - init_image['per_frame_avg_disp_vt_normals']) / init_image['per_frame_std_disp_vt_normals']
            ''' Normalization Strategy 4 '''
        # elif self.denoising_stra == "motion_to_rep":
        #     my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0]
        #     joints_noise = th.randn_like(rhand_joints)
        #     pert_rhand_joints = self.q_sample(rhand_joints, my_t, noise=joints_noise)
        #     pert_rhand_joints_denorm = pert_rhand_joints * self.std_jts.unsqueeze(0).to(rhand_joints.device) + self.avg_jts.unsqueeze(0).to(rhand_joints.device)
        #     # bsz x nf x nnj x nnb x 3
        #     rel_base_pts_to_pert_rhand_joints = pert_rhand_joints_denorm.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1)
        #     dist_base_pts_to_pert_rhand_joints = torch.sum(
        #         rel_base_pts_to_pert_rhand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1
        #     )
        else:
            raise ValueError(f"Unrecognized denoising stra: {self.denoising_stra}")

        # my_t = th.tensor([indices[-1]] * shape[0], device=device)
        my_t = th.tensor([indices[0]] * shape[0], device=device)
        # clean_joint_seq_latents = model(input_data, self._scale_timesteps(my_t))
        # noise_joint_seq_latents = th.randn_like(clean_joint_seq_latents)
        # pert_joint_seq_latents: bsz x seq x d #
        # pert_joint_seq_latents = self.q_sample(clean_joint_seq_latents.permute(1, 0, 2).contiguous(), my_t, noise=noise_joint_seq_latents.permute(1, 0, 2).contiguous()).permute(1, 0, 2).contiguous()

        # ### Add noise to rel_base_pts_to_rhand_joints ###
        # noise_rel_base_pts_to_rhand_joints = th.randn_like(rel_base_pts_to_rhand_joints)  ## bsz x ws x nnjts x 3 -> for each joint point #
        # pert_rel_base_pts_to_rhand_joints = self.q_sample(  ## bsz x ws x nnjts x 3 --> pert in the spatial space ###
        #     rel_base_pts_to_rhand_joints, my_t, noise_rel_base_pts_to_rhand_joints
        # )

        # bsz x ws x nnj x nnb x 3 #
        # maxx_pert_basejtsrel, _ = torch.max(pert_rel_base_pts_to_rhand_joints.view(-1, 3), dim=0)
        # minn_pert_basejtsrel, _ = torch.min(pert_rel_base_pts_to_rhand_joints.view(-1, 3), dim=0)
        # maxx_basejtsrel, _ = torch.max(rel_base_pts_to_rhand_joints.view(-1, 3), dim=0)
        # minn_basejtsrel, _ = torch.min(rel_base_pts_to_rhand_joints.view(-1, 3), dim=0)
        # print(f"maxx_pert_basejtsrel: {maxx_pert_basejtsrel}, minn_pert_basejtsrel: {minn_pert_basejtsrel}, maxx_basejtsrel: {maxx_basejtsrel}, minn_basejtsrel: {minn_basejtsrel}")

        ### Add noise to the average joints sequence ###
        # noise_avg_joints_sequence = th.randn_like(avg_joints_sequence)
        # pert_avg_joints_sequence = self.q_sample(  # bsz x nnjts x 3 --> the avg joints for each sequence
        #     avg_joints_sequence, my_t, noise_avg_joints_sequence
        # )
        # pert_avg_joints_sequence = avg_joints_sequence

        # joints_offset_sequence #
        # noise_joints_offset_sequence = th.randn_like(joints_offset_sequence)
        # print(f"my_t: {my_t}")
        # pert_joints_offset_sequence = self.q_sample(
        #     joints_offset_sequence, my_t, noise_joints_offset_sequence
        # )
        pert_joints_offset_sequence = joints_offset_sequence

        # sv_pert_dict = {
        #     'joints_offset_sequence': joints_offset_sequence.detach().cpu().numpy(),
        #     'pert_joints_offset_sequence': pert_joints_offset_sequence.detach().cpu().numpy(),
        #     'noise_joints_offset_sequence': noise_joints_offset_sequence.detach().cpu().numpy(),
        #     'joints_offset_sequence_ori': joints_offset_sequence_ori.detach().cpu().numpy(),
        #     'rhand_joints_ori': rhand_joints.detach().cpu().numpy(),
        # }
        # sv_pert_dict_fn = "tot_pert_jts_sequence_dict.npy"
        # np.save(sv_pert_dict_fn, sv_pert_dict)
        # print(f"pert data saved to {sv_pert_dict_fn}")

        # if self.args.rnd_noise:
        #     pert_joints_offset_sequence = noise_joints_offset_sequence
        #     pert_avg_joints_sequence = noise_avg_joints_sequence
        # tot_pert_joint = pert_joints_offset_sequence * std_joints_sequence.unsqueeze(1) + pert_avg_joints_sequence.unsqueeze(1)
        # np.save("tot_pert_joint.npy", tot_pert_joint.detach().cpu().numpy())
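        # Summary of the latent denoising path below (assuming the model exposes
        # 'rel_base_pts_latents'): (1) encode the clean inputs to latents,
        # (2) diffuse them to step indices[0] via q_sample, then (3) iteratively
        # denoise with p_sample and map the result back to joint space.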
        input_data = {
            'base_pts': base_pts,
            'base_normals': base_normals,
            # 'rel_base_pts_to_rhand_joints': rel_base_pts_to_pert_rhand_joints,
            # 'dist_base_pts_to_rhand_joints': dist_base_pts_to_pert_rhand_joints,
            # 'pert_rhand_joints': pert_normed_rhand_joints,
            # 'pert_rhand_joints': pert_scaled_rhand_joints,
            'rhand_joints': rhand_joints,
            # 'rel_base_pts_to_rhand_joints': rel_base_pts_to_rhand_joints.clone(),
            # 'avg_joints_sequence': avg_joints_sequence,
            # 'pert_avg_joints_sequence': pert_avg_joints_sequence,  ## pert avg joints sequence
            # 'pert_rel_base_pts_to_rhand_joints': pert_rel_base_pts_to_rhand_joints,  ## pert relative base pts to rhand joints ##
            'pert_joints_offset_sequence': pert_joints_offset_sequence,
        }
        # if 'sampled_base_pts_nearest_obj_pc' in init_image:  ## init image ##
        #     input_data.update(ambient_init_image)
        input_data.update(
            {
                'e_disp_rel_to_base_along_normals': e_disp_rel_to_base_along_normals,
                'e_disp_rel_to_baes_vt_normals': e_disp_rel_to_baes_vt_normals,
                'vel_obj_pts_to_hand_pts': vel_obj_pts_to_hand_pts,
                'obj_pts_disp': obj_pts_disp
            }
        )
        input_data.update(init_image_avg_std_stats)
        input_data['rhand_joints'] = rhand_joints  # normalized

        model_kwargs = {
            k: val for k, val in init_image.items() if k not in input_data
        }

        if progress:
            # Lazy import so that we don't depend on tqdm.
            from tqdm.auto import tqdm
            indices = tqdm(indices)

        # dec_in_dict =
        latent_out_dict = model(input_data, my_t)  # seq x bsz x dim #
        rel_base_pts_latents = latent_out_dict['rel_base_pts_latents']
        if 'rel_base_pts_latents_mean' in latent_out_dict:
            rel_base_pts_latents = latent_out_dict['rel_base_pts_latents_mean']

        maxx_latents, _ = torch.max(rel_base_pts_latents.view(rel_base_pts_latents.size(0) * rel_base_pts_latents.size(1), -1), dim=0)
        minn_latents, _ = torch.min(rel_base_pts_latents.view(rel_base_pts_latents.size(0) * rel_base_pts_latents.size(1), -1), dim=0)
        print(f"maxx_latents: {maxx_latents[:10]}, minn_latents: {minn_latents[:10]}")

        noise_rel_base_pts_latents = torch.randn_like(rel_base_pts_latents)
        pert_rel_base_pts_latents = self.q_sample(
            rel_base_pts_latents, my_t, noise_rel_base_pts_latents
        )
        if self.args.rnd_noise:
            pert_rel_base_pts_latents = noise_rel_base_pts_latents
        # pert_rel_base_pts_latents = rel_base_pts_latents
        pert_maxx_latents, _ = torch.max(pert_rel_base_pts_latents.view(pert_rel_base_pts_latents.size(0) * pert_rel_base_pts_latents.size(1), -1), dim=0)
        pert_minn_latents, _ = torch.min(pert_rel_base_pts_latents.view(pert_rel_base_pts_latents.size(0) * pert_rel_base_pts_latents.size(1), -1), dim=0)
        print(f"pert_maxx_latents: {pert_maxx_latents[:10]}, pert_minn_latents: {pert_minn_latents[:10]}")

        # pert_avg_joints_sequence = avg_joints_sequence
        dec_in_dict = {
            'rel_base_pts_latents': pert_rel_base_pts_latents,
            'input_data': input_data,
        }

        for i_idx, i in enumerate(indices):
            t = th.tensor([i] * shape[0], device=device)
            if randomize_class and 'y' in model_kwargs:
                model_kwargs['y'] = th.randint(low=0,
                                               high=model.num_classes,
                                               size=model_kwargs['y'].shape,  # size of y
                                               device=model_kwargs['y'].device)  # device of y
            with th.no_grad():  # inter_optim; p_sample_with_grad #
                sample_fn = self.p_sample_with_grad if cond_fn_with_grad else self.p_sample
                out = sample_fn(
                    model,
                    dec_in_dict,  ## sample from the decoder input dict ##
                    t,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn,
                    cond_fn=cond_fn,
                    model_kwargs=model_kwargs,
                    const_noise=const_noise,
                )
                if self.diff_basejtsrel:
                    ## seq latents sample ##
                    basejtsrel_seq_latents_sample = out["basejtsrel_seq_latents_sample"]
                    # if self.args.pred_joints_offset:  # basejtsrel_seq_latents_sample: bsz x nf x nnj x 3 #
                    #     sampled_rhand_joints = basejtsrel_seq_latents_sample * std_joints_sequence.unsqueeze(1) + avg_jts_outputs_sample.unsqueeze(1)
                    sampled_rhand_joints = out['pred_joints'] * std_joints_sequence.unsqueeze(1) + avg_joints_sequence.unsqueeze(1)
                    basejtsrel_seq_dec_in_dict = {
                        # 'pert_avg_joints_sequence': out["avg_jts_outputs_sample"] if 'avg_jts_outputs_sample' in out else pert_avg_joints_sequence,  ## for avg-jts sequence ##
                        # 'pert_rel_base_pts_to_rhand_joints': out["basejtsrel_seq_latents_sample"],  ## pert relative base pts to rhand joints ##
                        'sampled_rhand_joints': sampled_rhand_joints,  ## sampled rhand joints ##
                        # 'pert_joints_offset_sequence': out["basejtsrel_seq_latents_sample"],
                    }
                    input_data.update(basejtsrel_seq_dec_in_dict)
                    basejtsrel_dec_in_dict = {
                        'rel_base_pts_latents': out["basejtsrel_seq_latents_sample"],
                        'input_data': input_data
                    }
                else:
                    basejtsrel_seq_input_dict = {}
                    basejtsrel_seq_dec_in_dict = {}
                    basejtsrel_dec_in_dict = {}  # keep the update below well-defined when diff_basejtsrel is off
            dec_in_dict.update(basejtsrel_dec_in_dict)
            yield input_data

    def ddim_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        eta=0.0,
    ):
        """
        Sample x_{t-1} from the model using DDIM.

        Same usage as p_sample().
        """
        out_orig = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        if cond_fn is not None:
            out = self.condition_score(cond_fn, out_orig, x, t, model_kwargs=model_kwargs)
        else:
            out = out_orig

        # Usually our model outputs epsilon, but we re-derive it
        # in case we used x_start or x_prev prediction.
        eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])

        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
        alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
        sigma = (
            eta
            * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
            * th.sqrt(1 - alpha_bar / alpha_bar_prev)
        )
        # Equation 12.
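        # For reference, the DDIM update (Eq. 12 in Song et al.) computed here:
        #   sigma_t = eta * sqrt((1 - abar_{t-1}) / (1 - abar_t)) * sqrt(1 - abar_t / abar_{t-1})
        #   x_{t-1} = sqrt(abar_{t-1}) * x0_hat + sqrt(1 - abar_{t-1} - sigma_t^2) * eps + sigma_t * z
        # with eta = 0 giving the deterministic DDIM path.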
##; hand position and relative noise = th.randn_like(x) mean_pred = ( out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps ) nonzero_mask = ( (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) ) # no noise when t == 0 sample = mean_pred + nonzero_mask * sigma * noise return {"sample": sample, "pred_xstart": out_orig["pred_xstart"]} def ddim_sample_with_grad( self, model, x, t, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, eta=0.0, ): """ Sample x_{t-1} from the model using DDIM. Same usage as p_sample(). """ with th.enable_grad(): x = x.detach().requires_grad_() out_orig = self.p_mean_variance( model, x, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) if cond_fn is not None: out = self.condition_score_with_grad(cond_fn, out_orig, x, t, model_kwargs=model_kwargs) else: out = out_orig out["pred_xstart"] = out["pred_xstart"].detach() # Usually our model outputs epsilon, but we re-derive it # in case we used x_start or x_prev prediction. eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) sigma = ( eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * th.sqrt(1 - alpha_bar / alpha_bar_prev) ) # Equation 12. noise = th.randn_like(x) mean_pred = ( out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps ) nonzero_mask = ( (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) ) # no noise when t == 0 sample = mean_pred + nonzero_mask * sigma * noise return {"sample": sample, "pred_xstart": out_orig["pred_xstart"].detach()} ## def ddim_reverse_sample( self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None, eta=0.0, ): """ Sample x_{t+1} from the model using DDIM reverse ODE. """ assert eta == 0.0, "Reverse ODE only for deterministic path" out = self.p_mean_variance( model, x, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) # Usually our model outputs epsilon, but we re-derive it # in case we used x_start or x_prev prediction. eps = ( _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - out["pred_xstart"] ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) # Equation 12. reversed mean_pred = ( out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps ) return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} def ddim_sample_loop( self, model, shape, noise=None, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, progress=False, eta=0.0, skip_timesteps=0, init_image=None, randomize_class=False, cond_fn_with_grad=False, dump_steps=None, const_noise=False, ): """ Generate samples from the model using DDIM. Same usage as p_sample_loop(). 
""" if dump_steps is not None: raise NotImplementedError() if const_noise == True: raise NotImplementedError() final = None for sample in self.ddim_sample_loop_progressive( model, shape, noise=noise, clip_denoised=clip_denoised, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, device=device, progress=progress, eta=eta, skip_timesteps=skip_timesteps, init_image=init_image, randomize_class=randomize_class, cond_fn_with_grad=cond_fn_with_grad, ): final = sample return final["sample"] def ddim_sample_loop_progressive( self, model, shape, noise=None, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, progress=False, eta=0.0, skip_timesteps=0, init_image=None, randomize_class=False, cond_fn_with_grad=False, ): """ Use DDIM to sample from the model and yield intermediate samples from each timestep of DDIM. Same usage as p_sample_loop_progressive(). """ if device is None: device = next(model.parameters()).device assert isinstance(shape, (tuple, list)) if noise is not None: img = noise else: img = th.randn(*shape, device=device) if skip_timesteps and init_image is None: init_image = th.zeros_like(img) indices = list(range(self.num_timesteps - skip_timesteps))[::-1] if init_image is not None: my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0] img = self.q_sample(init_image, my_t, img) if progress: # Lazy import so that we don't depend on tqdm. from tqdm.auto import tqdm indices = tqdm(indices) for i in indices: t = th.tensor([i] * shape[0], device=device) if randomize_class and 'y' in model_kwargs: model_kwargs['y'] = th.randint(low=0, high=model.num_classes, size=model_kwargs['y'].shape, device=model_kwargs['y'].device) with th.no_grad(): sample_fn = self.ddim_sample_with_grad if cond_fn_with_grad else self.ddim_sample out = sample_fn( model, img, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, eta=eta, ) yield out img = out["sample"] def plms_sample( self, model, x, t, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, cond_fn_with_grad=False, order=2, old_out=None, ): """ Sample x_{t-1} from the model using Pseudo Linear Multistep. Same usage as p_sample(). """ if not int(order) or not 1 <= order <= 4: raise ValueError('order is invalid (should be int from 1-4).') def get_model_output(x, t): with th.set_grad_enabled(cond_fn_with_grad and cond_fn is not None): x = x.detach().requires_grad_() if cond_fn_with_grad else x out_orig = self.p_mean_variance( model, x, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) if cond_fn is not None: if cond_fn_with_grad: out = self.condition_score_with_grad(cond_fn, out_orig, x, t, model_kwargs=model_kwargs) x = x.detach() else: out = self.condition_score(cond_fn, out_orig, x, t, model_kwargs=model_kwargs) else: out = out_orig # Usually our model outputs epsilon, but we re-derive it # in case we used x_start or x_prev prediction. 
eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) return eps, out, out_orig alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) eps, out, out_orig = get_model_output(x, t) if order > 1 and old_out is None: # Pseudo Improved Euler old_eps = [eps] mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev) * eps eps_2, _, _ = get_model_output(mean_pred, t - 1) eps_prime = (eps + eps_2) / 2 pred_prime = self._predict_xstart_from_eps(x, t, eps_prime) mean_pred = pred_prime * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev) * eps_prime else: # Pseudo Linear Multistep (Adams-Bashforth) old_eps = old_out["old_eps"] old_eps.append(eps) cur_order = min(order, len(old_eps)) if cur_order == 1: eps_prime = old_eps[-1] elif cur_order == 2: eps_prime = (3 * old_eps[-1] - old_eps[-2]) / 2 elif cur_order == 3: eps_prime = (23 * old_eps[-1] - 16 * old_eps[-2] + 5 * old_eps[-3]) / 12 elif cur_order == 4: eps_prime = (55 * old_eps[-1] - 59 * old_eps[-2] + 37 * old_eps[-3] - 9 * old_eps[-4]) / 24 else: raise RuntimeError('cur_order is invalid.') pred_prime = self._predict_xstart_from_eps(x, t, eps_prime) mean_pred = pred_prime * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev) * eps_prime if len(old_eps) >= order: old_eps.pop(0) nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) sample = mean_pred * nonzero_mask + out["pred_xstart"] * (1 - nonzero_mask) return {"sample": sample, "pred_xstart": out_orig["pred_xstart"], "old_eps": old_eps} def plms_sample_loop( self, model, shape, noise=None, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, progress=False, skip_timesteps=0, init_image=None, randomize_class=False, cond_fn_with_grad=False, order=2, ): """ Generate samples from the model using Pseudo Linear Multistep. Same usage as p_sample_loop(). """ final = None for sample in self.plms_sample_loop_progressive( model, shape, noise=noise, clip_denoised=clip_denoised, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, device=device, progress=progress, skip_timesteps=skip_timesteps, init_image=init_image, randomize_class=randomize_class, cond_fn_with_grad=cond_fn_with_grad, order=order, ): final = sample return final["sample"] def plms_sample_loop_progressive( self, model, shape, noise=None, clip_denoised=True, denoised_fn=None, cond_fn=None, model_kwargs=None, device=None, progress=False, skip_timesteps=0, init_image=None, randomize_class=False, cond_fn_with_grad=False, order=2, ): """ Use PLMS to sample from the model and yield intermediate samples from each timestep of PLMS. Same usage as p_sample_loop_progressive(). """ if device is None: device = next(model.parameters()).device assert isinstance(shape, (tuple, list)) if noise is not None: img = noise else: img = th.randn(*shape, device=device) if skip_timesteps and init_image is None: init_image = th.zeros_like(img) indices = list(range(self.num_timesteps - skip_timesteps))[::-1] if init_image is not None: my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0] img = self.q_sample(init_image, my_t, img) if progress: # Lazy import so that we don't depend on tqdm. 
from tqdm.auto import tqdm indices = tqdm(indices) old_out = None for i in indices: t = th.tensor([i] * shape[0], device=device) if randomize_class and 'y' in model_kwargs: model_kwargs['y'] = th.randint(low=0, high=model.num_classes, size=model_kwargs['y'].shape, device=model_kwargs['y'].device) with th.no_grad(): out = self.plms_sample( model, img, t, clip_denoised=clip_denoised, denoised_fn=denoised_fn, cond_fn=cond_fn, model_kwargs=model_kwargs, cond_fn_with_grad=cond_fn_with_grad, order=order, old_out=old_out, ) yield out old_out = out img = out["sample"] def _vb_terms_bpd( self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None ): """ Get a term for the variational lower-bound. The resulting units are bits (rather than nats, as one might expect). This allows for comparison to other papers. :return: a dict with the following keys: - 'output': a shape [N] tensor of NLLs or KLs. - 'pred_xstart': the x_0 predictions. """ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( x_start=x_start, x_t=x_t, t=t ) out = self.p_mean_variance( model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs ) kl = normal_kl( true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] ) kl = mean_flat(kl) / np.log(2.0) decoder_nll = -discretized_gaussian_log_likelihood( x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] ) assert decoder_nll.shape == x_start.shape decoder_nll = mean_flat(decoder_nll) / np.log(2.0) # At the first timestep return the decoder NLL, # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) output = th.where((t == 0), decoder_nll, kl) return {"output": output, "pred_xstart": out["pred_xstart"]} ## ## training losses ## ## training losses ## def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, dataset=None): """ # training losses # training losses for rel/dist representations ## Compute training losses for a single timestep. :param model: the model to evaluate loss on. :param x_start: the [N x C x ...] tensor of inputs. :param t: a batch of timestep indices. :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :param noise: if specified, the specific Gaussian noise to try to remove. :return: a dict with the key "loss" containing a tensor of shape [N]. Some mean or variance settings may also have other keys. 
""" if self.args.train_diff: # set enc to evals # # print(f"Setitng encoders to eval mode") model.model.set_enc_to_eval() enc = model.model ## model.model mask = model_kwargs['y']['mask'] ## rot2xyz get_xyz = lambda sample: enc.rot2xyz(sample, mask=None, pose_rep=enc.pose_rep, translation=enc.translation, glob=enc.glob, ## rot2xyz; ## # jointstype='vertices', # 3.4 iter/sec # USED ALSO IN MotionCLIP jointstype='smpl', # 3.4 iter/sec vertstrans=False) # bsz x ws x nnj x 3 # # base_pts: bsz x nnb x 3 # # base_normals: bsz x nnb x 3 # # bsz x ws x nnjts x 3 # rhand_joints = x_start['rhand_joints'] # bsz x nnbase x 3 # base_pts = x_start['base_pts'] # bsz x ws x nnbase x 3 # base_normals = x_start['base_normals'] avg_joints_sequence = torch.mean(rhand_joints, dim=1) # bsz x nnj x 3 ---> for rhand joints ## std_joints_sequence = torch.std(rhand_joints, dim=1) if self.args.joint_std_v2: std_joints_sequence = torch.std(rhand_joints.view(rhand_joints.size(0), -1, 3), dim=1).unsqueeze(1) avg_joints_sequence = torch.mean(rhand_joints.view(rhand_joints.size(0), -1, 3), dim=1).unsqueeze(1) elif self.args.joint_std_v3: # std_joints_sequence = torch.std(rhand_joints.view(rhand_joints.size(0), -1, 3), dim=1).unsqueeze(1) avg_joints_sequence = torch.mean(rhand_joints.view(rhand_joints.size(0), -1, 3), dim=1).unsqueeze(1) std_joints_sequence = torch.std((rhand_joints - avg_joints_sequence.unsqueeze(1)).view(rhand_joints.size(0), -1), dim=1).unsqueeze(1).unsqueeze(1) joints_offset_sequence = rhand_joints - avg_joints_sequence.unsqueeze(1) # bsz x nf x nnj x 3 # if self.args.jts_sclae_stra == "std": joints_offset_sequence = joints_offset_sequence / std_joints_sequence.unsqueeze(1) else: std_joints_sequence = torch.ones_like(std_joints_sequence) # # bsz x ws x nnjts x nnbase x 3 # # rel_base_pts_to_rhand_joints = x_start['rel_base_pts_to_rhand_joints'] # # bsz x ws x nnjts x nnbase # # dist_base_pts_to_rhand_joints = x_start['dist_base_pts_to_rhand_joints'] ## no diffu # normalization strategy for joints and that for the representation values # if 'sampled_base_pts_nearest_obj_pc' in x_start: ambient_xstart_dict = { 'sampled_base_pts_nearest_obj_pc': x_start['sampled_base_pts_nearest_obj_pc'], 'sampled_base_pts_nearest_obj_vns': x_start['sampled_base_pts_nearest_obj_vns'], } # per_frame_avg_disp_along_normals, per_frame_std_disp_along_normals # per_frame_avg_disp_vt_normals, per_frame_std_disp_vt_normals if self.args.wo_e_normalization: x_start['per_frame_avg_disp_along_normals'] = torch.zeros_like(x_start['per_frame_avg_disp_along_normals']) x_start['per_frame_avg_disp_vt_normals'] = torch.zeros_like(x_start['per_frame_avg_disp_vt_normals']) x_start['per_frame_std_disp_along_normals'] = torch.ones_like(x_start['per_frame_std_disp_along_normals']) x_start['per_frame_std_disp_vt_normals'] = torch.ones_like(x_start['per_frame_std_disp_vt_normals']) if self.args.wo_rel_normalization: # per_frame_avg_joints_rel, per_frame_std_joints_rel # x_start['per_frame_avg_joints_rel'] = torch.zeros_like(x_start['per_frame_avg_joints_rel']) x_start['per_frame_std_joints_rel'] = torch.ones_like(x_start['per_frame_std_joints_rel']) ## rel_base_pts_to_rhand_joints: bsz x ws x nnj x nnb x 3 ## ## base pts to rhand joints ## # base_pts: bsz x nnb x 3 # # base_normals: bsz x nnb x 3 ## ## relative joint positions ### rel_base_pts_to_rhand_joints = rhand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1) # bsz x nnj x nnb x 3 # rhand_joints ## x_start['per_frame_avg_joints_rel'] = torch.mean(rel_base_pts_to_rhand_joints, 
        # bsz x ws x nnj x nnb x 3 #
        rel_base_pts_to_rhand_joints = (rel_base_pts_to_rhand_joints - x_start['per_frame_avg_joints_rel']) / x_start['per_frame_std_joints_rel']
        # rel_base_pts_to_rhand_joints: bsz x ws x nnj x nnb x 3 ##
        # data normalization: construct statistics, normalize values #
        joints_scaling_factor = 5.

        ''' GET rel and dists '''
        ## rep and rhand_joints ##
        if self.denoising_stra == "rep":
            # bsz x ws x nnj x nnb x 3 #
            # avg_jts: 1 x nnj x 3; std_jts: 1 x nnj x 3 #
            # rhand_joints: bsz x ws x nnj x 3; normalize rhand joints #
            normed_rhand_joints = (rhand_joints - self.avg_jts.unsqueeze(0).to(rhand_joints.device)) / self.std_jts.unsqueeze(0).to(rhand_joints.device)
            noise_rhand_joints = th.randn_like(normed_rhand_joints)
            pert_normed_rhand_joints = self.q_sample(normed_rhand_joints, t, noise=noise_rhand_joints)

            ### scale rhand joints ###
            # rhand joints: bsz x ws x nnj x 3; each joint 1 x 3 -> normalization #
            exp_rhand_joints = rhand_joints.view(rhand_joints.size(0), rhand_joints.size(1) * rhand_joints.size(2), 3)
            if self.args.jts_sclae_stra == "std":
                avg_exp_rhand_joints = torch.mean(exp_rhand_joints, dim=1, keepdim=True)
                extents_rhand_joints = torch.std(exp_rhand_joints, dim=1, keepdim=True)
            elif self.args.jts_sclae_stra == "bbox":
                ### bounding box ###
                maxx_exp_rhand_joints, _ = torch.max(exp_rhand_joints, dim=1, keepdim=True)
                minn_exp_rhand_joints, _ = torch.min(exp_rhand_joints, dim=1, keepdim=True)
                avg_exp_rhand_joints = (maxx_exp_rhand_joints + minn_exp_rhand_joints) / 2.  # bbox center
                extents_rhand_joints = maxx_exp_rhand_joints - minn_exp_rhand_joints  # bsz x 1 x 3
                extents_rhand_joints = torch.sqrt(torch.sum(extents_rhand_joints ** 2, dim=-1, keepdim=True))
                ### bounding box ###
            else:
                raise ValueError(f"Unrecognized jts_sclae_stra: {self.args.jts_sclae_stra}")

            rhand_joints = (rhand_joints - avg_exp_rhand_joints.unsqueeze(1)) / extents_rhand_joints.unsqueeze(1)
            scaled_rhand_joints = rhand_joints * joints_scaling_factor
            noise_scaled_rhand_joints = th.randn_like(scaled_rhand_joints)
            pert_scaled_rhand_joints = self.q_sample(scaled_rhand_joints, t, noise=noise_scaled_rhand_joints)

            # pert_rhand_joints: bsz x ws x nnj x 3 #
            pert_rhand_joints = pert_normed_rhand_joints * self.std_jts.unsqueeze(0).to(rhand_joints.device) + self.avg_jts.unsqueeze(0).to(rhand_joints.device)

            ### Calculate moving-related energies ###
            # denormed_rhand_joints: bsz x nf x nnj x 3 -> denormed rhand joints here #
            ## denormed base pts to rhand joints -> denormed rel positions ##
            denormed_rel_base_pts_to_rhand_joints = rel_base_pts_to_rhand_joints * x_start['per_frame_std_joints_rel'] + x_start['per_frame_avg_joints_rel']
            denormed_rhand_joints = rhand_joints * extents_rhand_joints.unsqueeze(1) + avg_exp_rhand_joints.unsqueeze(1)
            denormed_dist_rel_base_pts_to_pert_rhand_joints = torch.sum(
                denormed_rel_base_pts_to_rhand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1
            )
            k_f = 1.
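            # The displacement split below uses the standard decomposition
            # d = <d, n> n + d_perp, with ||d_perp||^2 = ||d||^2 - <d, n>^2;
            # the code forms d_perp explicitly and takes its norm.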
            ## l2 distances from base pts to pert rhand joints ##
            # l2_rel_base_pts_to_pert_rhand_joints: bsz x nf x nnj x nnb #
            l2_rel_base_pts_to_pert_rhand_joints = torch.norm(denormed_rel_base_pts_to_rhand_joints, dim=-1)
            ### att_forces ###
            att_forces = torch.exp(-k_f * l2_rel_base_pts_to_pert_rhand_joints)  # bsz x nf x nnj x nnb #
            # bsz x (ws - 1) x nnj x nnb #
            att_forces = att_forces[:, :-1, :, :]  # attraction forces, dropping the last frame #
            # rhand_joints: ws x nnj x 3 -> (ws - 1) x nnj x 3 #
            # bsz x (ws - 1) x nnj x 3 --> frame-to-frame displacements #
            denormed_rhand_joints_disp = denormed_rhand_joints[:, 1:, :, :] - denormed_rhand_joints[:, :-1, :, :]
            # signed distances along the base normals: bsz x (ws - 1) x nnj x nnb #
            signed_dist_base_pts_to_rhand_joints_along_normal = torch.sum(
                base_normals.unsqueeze(1).unsqueeze(1) * denormed_rhand_joints_disp.unsqueeze(-2), dim=-1
            )
            # rel_base_pts_to_rhand_joints_vt_normal: bsz x (ws - 1) x nnj x nnb x 3 -> the relative positions vertical to base normals #
            rel_base_pts_to_rhand_joints_vt_normal = denormed_rhand_joints_disp.unsqueeze(-2) - signed_dist_base_pts_to_rhand_joints_along_normal.unsqueeze(-1) * base_normals.unsqueeze(1).unsqueeze(1)
            dist_base_pts_to_rhand_joints_vt_normal = torch.sqrt(torch.sum(
                rel_base_pts_to_rhand_joints_vt_normal ** 2, dim=-1
            ))
            k_a = 1.
            k_b = 1.
            e_disp_rel_to_base_along_normals = k_a * att_forces * torch.abs(signed_dist_base_pts_to_rhand_joints_along_normal)
            # (ws - 1) x nnj x nnb # -> dist vt normals #
            e_disp_rel_to_baes_vt_normals = k_b * att_forces * dist_base_pts_to_rhand_joints_vt_normal
            # nf x nnj x nnb ---> dist_vt_normals -> nf x nnj x nnb #
            # e_disp_rel_to_base_along_normals, e_disp_rel_to_baes_vt_normals #
            # print("per frame avg disp along normals,", x_start['per_frame_avg_disp_along_normals'].size(), "per frame std along normals", x_start['per_frame_std_disp_along_normals'].size(), "e_disp_rel_to_base_along_normals", e_disp_rel_to_base_along_normals.size(), "e_disp_rel_to_baes_vt_normals", e_disp_rel_to_baes_vt_normals.size())
            # per_frame_avg_disp_along_normals, per_frame_std_disp_along_normals
            # per_frame_avg_disp_vt_normals, per_frame_std_disp_vt_normals
            e_disp_rel_to_base_along_normals = (e_disp_rel_to_base_along_normals - x_start['per_frame_avg_disp_along_normals']) / x_start['per_frame_std_disp_along_normals']
            e_disp_rel_to_baes_vt_normals = (e_disp_rel_to_baes_vt_normals - x_start['per_frame_avg_disp_vt_normals']) / x_start['per_frame_std_disp_vt_normals']
        elif self.denoising_stra == "motion_to_rep":
            # print(f"Using denoising stra: {self.denoising_stra}")
            joints_noise = torch.randn_like(rhand_joints)
            pert_rhand_joints = self.q_sample(rhand_joints, t, noise=joints_noise)  # q_sample for the noisy joints
            # pert_rhand_joints: bsz x nf x nnj x 3 --> pert joints #
            # base_pts: bsz x nnb x 3; avg jts and std jts ##
            pert_rhand_joints_denorm = pert_rhand_joints * self.std_jts.unsqueeze(0).to(rhand_joints.device) + self.avg_jts.unsqueeze(0).to(rhand_joints.device)
            # bsz x nf x nnj x nnb x 3
            rel_base_pts_to_pert_rhand_joints = pert_rhand_joints_denorm.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1)
            dist_base_pts_to_pert_rhand_joints = torch.sum(
                rel_base_pts_to_pert_rhand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1
            )
        else:
            raise ValueError(f"Unrecognized denoising_stra: {self.denoising_stra}")
        ''' GET rel and dists '''
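        # Forward diffusion used for the perturbations below (reference):
        #   q_sample(x_0, t, eps) = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps
        # so t = 0 is almost clean and t = T - 1 is close to pure noise.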
        ### Add noise to rel_base_pts_to_rhand_joints ###
        noise_rel_base_pts_to_rhand_joints = th.randn_like(rel_base_pts_to_rhand_joints)  ## bsz x ws x nnjts x 3 -> for each joint point #
        pert_rel_base_pts_to_rhand_joints = self.q_sample(  ## bsz x ws x nnjts x 3 --> pert in the spatial space ###
            rel_base_pts_to_rhand_joints, t, noise_rel_base_pts_to_rhand_joints
        )
        # bsz x ws x nnj x nnb x 3 #
        # maxx_pert_basejtsrel, _ = torch.max(pert_rel_base_pts_to_rhand_joints.view(-1, 3), dim=0)
        # minn_pert_basejtsrel, _ = torch.min(pert_rel_base_pts_to_rhand_joints.view(-1, 3), dim=0)
        # maxx_basejtsrel, _ = torch.max(rel_base_pts_to_rhand_joints.view(-1, 3), dim=0)
        # minn_basejtsrel, _ = torch.min(rel_base_pts_to_rhand_joints.view(-1, 3), dim=0)
        # print(f"maxx_pert_basejtsrel: {maxx_pert_basejtsrel}, minn_pert_basejtsrel: {minn_pert_basejtsrel}, maxx_basejtsrel: {maxx_basejtsrel}, minn_basejtsrel: {minn_basejtsrel}")

        ### Add noise to the average joints sequence ###
        noise_avg_joints_sequence = th.randn_like(avg_joints_sequence)
        pert_avg_joints_sequence = self.q_sample(  # bsz x nnjts x 3 --> the avg joints for each sequence
            avg_joints_sequence, t, noise_avg_joints_sequence
        )
        # joints_offset_sequence #
        noise_joints_offset_sequence = th.randn_like(joints_offset_sequence)
        # pert_joints_offset_sequence = self.q_sample(
        #     joints_offset_sequence, t, noise_joints_offset_sequence
        # )
        pert_joints_offset_sequence = joints_offset_sequence

        input_data = {
            'base_pts': base_pts.clone(),  # base pts #
            'base_normals': base_normals.clone(),  # base normals #
            # 'rel_base_pts_to_rhand_joints': rel_base_pts_to_pert_rhand_joints.clone(),
            ## rel_base_pts_to_rhand_joints: bsz x ws x nnj x nnb x 3 ## -> 1) encode to the latent space; 2) add noise in the latent space; 3) denoise latent codes; 4) use denoised latent codes for further prediction ###
            'rel_base_pts_to_rhand_joints': rel_base_pts_to_rhand_joints.clone(),
            # 'dist_base_pts_to_rhand_joints': dist_base_pts_to_pert_rhand_joints.clone(),
            # 'pert_rhand_joints': pert_normed_rhand_joints,
            # scaled_rhand_joints, pert_scaled_rhand_joints
            'pert_rhand_joints': pert_scaled_rhand_joints,
            'rhand_joints': rhand_joints,
            'avg_joints_sequence': avg_joints_sequence,  ## bsz x nnjoints x 3 here for the avg joints ##
            'pert_avg_joints_sequence': pert_avg_joints_sequence,
            'pert_rel_base_pts_to_rhand_joints': pert_rel_base_pts_to_rhand_joints,  ## pert relative base pts to rhand joints ##
            'pert_joints_offset_sequence': pert_joints_offset_sequence,
        }
        if 'sampled_base_pts_nearest_obj_pc' in x_start:
            input_data.update(ambient_xstart_dict)
        # dist_base_pts_to_pert_rhand_joints, rel_base_pts_to_pert_rhand_joints #
        # e_disp_rel_to_base_along_normals, e_disp_rel_to_baes_vt_normals: bsz x (ws - 1) x nnj x nnb #
        input_data.update(
            {
                'e_disp_rel_to_base_along_normals': e_disp_rel_to_base_along_normals,
                'e_disp_rel_to_baes_vt_normals': e_disp_rel_to_baes_vt_normals,
                ## clean values; the denoising space is then transformed to the latent space;
                ## noisy inputs -> latent code --> can we really denoise them correctly?
            }
        )
        # input_data.update(
        #     {k: x_start[k].clone() for k in x_start if k not in input_data}
        # )
        # rel_base_pts_to_rhand_joints in the input_data #

        if model_kwargs is None:
            model_kwargs = {}

        terms = {}
        # latents in the latent space; sequence latents #
        # if self.args.train_diff:
        #     with torch.no_grad():
        #         out_dict = model(input_data, self._scale_timesteps(t).clone())
        # else:
        #     # clean_joint_seq_latents: seq x bs x d #
        #     clean_joint_seq_latents = model(input_data, self._scale_timesteps(t).clone())
        ### the strategy of removing noise from the corresponding quantities ###
        out_dict = model(input_data, self._scale_timesteps(t).clone())  ### get model output dictionary ###

        KL_loss = 0.
        terms['rot_mse'] = 0.

        ### diff_jts ###
        dec_in_dict = {}
        if self.diff_jts:
            ### Sample for perturbed joints seq latents ###
            clean_joint_seq_latents = out_dict["joint_seq_output"]
            noise_joint_seq_latents = th.randn_like(clean_joint_seq_latents)
            if self.args.const_noise:
                noise_joint_seq_latents = noise_joint_seq_latents[0].unsqueeze(0).repeat(noise_joint_seq_latents.size(0), 1, 1).contiguous()
            pert_joint_seq_latents = self.q_sample(clean_joint_seq_latents.permute(1, 0, 2).contiguous(), t, noise=noise_joint_seq_latents.permute(1, 0, 2).contiguous()).permute(1, 0, 2).contiguous()
            ### Sample for perturbed joints seq latents ###
            dec_in_dict['joints_seq_latents'] = pert_joint_seq_latents
            dec_in_dict['joints_seq_latents_enc'] = clean_joint_seq_latents

            if self.args.kl_weights > 0. and "joint_seq_output_mean" in out_dict and not self.args.train_diff:
                # clean_joint_seq_latents: seq_len x bs x d #
                log_p_joints_seq = model_util.standard_normal_logprob(clean_joint_seq_latents)
                log_p_joints_seq = log_p_joints_seq.permute(1, 0, 2).contiguous()  # bs x seq_len x d
                log_p_joints_seq = log_p_joints_seq.sum(dim=-1).mean(dim=-1)  # log_p_joints_seq
                entropy_joints_seq = model_util.gaussian_entropy(out_dict['joint_seq_output_logvar'].permute(1, 2, 0)).mean(dim=-1)
                loss_prior_joints_seq = (- log_p_joints_seq - entropy_joints_seq)
                KL_loss += loss_prior_joints_seq
            # log_pz = standard_normal_logprob(z).sum(dim=1)  # (B, ), independence assumption
            # entropy = gaussian_entropy(logvar=z_sigma)  # (B, ); reparameterized Gaussian
            # loss_prior = (-log_pz - entropy).mean()  # sum over feature space
        # only need x to be the target shape?
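        # Prior/KL term used here (assumed semantics of utils.model_util): for a
        # latent z ~ N(mu, sigma^2) with diagonal covariance,
        #   standard_normal_logprob(z) = -0.5 * (log(2*pi) + z^2)   per dimension,
        #   gaussian_entropy(logvar)   =  0.5 * sum(log(2*pi*e) + logvar),
        # and loss_prior = -log p(z) - H[q], i.e. a Monte-Carlo form of KL(q || N(0, I)).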
        if self.diff_basejtsrel:
            # noise_rel_base_pts_latents, pert_joints_offset_sequence, joints_offset_sequence #
            rel_base_pts_latents = out_dict['rel_base_pts_latents']
            if self.args.train_enc:
                pert_joints_offset_sequence = rel_base_pts_latents
            else:
                noise_rel_base_pts_latents = th.randn_like(rel_base_pts_latents.permute(1, 0, 2))
                pert_joints_offset_sequence = self.q_sample(
                    rel_base_pts_latents.detach().permute(1, 0, 2), t, noise_rel_base_pts_latents
                )
                pert_joints_offset_sequence = pert_joints_offset_sequence.permute(1, 0, 2)

            if self.args.use_vae and self.args.kl_weights > 0.:
                rel_base_pts_latents_mean = out_dict['rel_base_pts_latents_mean']
                rel_base_pts_latents_logvar = out_dict['rel_base_pts_latents_logvar']
                log_p_rel_base_pts_seq = model_util.standard_normal_logprob(rel_base_pts_latents)
                log_p_rel_base_pts_seq = log_p_rel_base_pts_seq.permute(1, 0, 2).contiguous()  # bs x seq_len x d
                log_p_rel_base_pts_seq = log_p_rel_base_pts_seq.sum(dim=-1).mean(dim=-1)
                entropy_rel_base_pts_seq = model_util.gaussian_entropy(rel_base_pts_latents_logvar.permute(1, 2, 0)).mean(dim=-1)
                loss_prior_rel_base_pts_seq = (- log_p_rel_base_pts_seq - entropy_rel_base_pts_seq)
                KL_loss += loss_prior_rel_base_pts_seq

            basejtsrel_dec_in_dict = {
                'rel_base_pts_latents': pert_joints_offset_sequence,
                'rel_base_pts_latents_enc': rel_base_pts_latents  # VAE space
            }
            # pert_basejtsrel_seq_latents = self.q_sample(basejtsrel_seq_latents.permute(1, 0, 2).contiguous(), t, noise=noise_basejtsrel_seq_latents.permute(1, 0, 2).contiguous()).permute(1, 0, 2).contiguous()
            # basejtsrel_seq_latents = basejtsrel_seq_latents  # .detach()
            ### Sample for perturbed basejtsrel seq latents ###
            # dec_in_dict['rel_base_pts_outputs'] = pert_basejtsrel_seq_latents.detach()
            # dec_in_dict['rel_base_pts_outputs_enc'] = basejtsrel_seq_latents.detach()
            # if self.args.kl_weights > 0.
and "rel_base_pts_outputs_mean" in out_dict and not self.args.train_diff: # # clean_joint_seq_latents: seq_len x bs x d # # log_p_rel_base_pts_seq = model_util.standard_normal_logprob(basejtsrel_seq_latents) # log_p_rel_base_pts_seq = log_p_rel_base_pts_seq.permute(1, 0, 2).contiguous() # # log_p_rel_base_pts_seq = log_p_rel_base_pts_seq.sum(dim=-1).mean(dim=-1) # # log_p_joints_seq # entropy_rel_base_pts_seq = model_util.gaussian_entropy(out_dict['rel_base_pts_outputs_logvar'].permute(1, 2, 0)).mean(dim=-1) # loss_prior_rel_base_pts_seq = (- log_p_rel_base_pts_seq - entropy_rel_base_pts_seq) # KL_loss += loss_prior_rel_base_pts_seq dec_in_dict.update(basejtsrel_dec_in_dict) out_dict = model.model.dec_latents_to_joints_with_t(dec_in_dict, input_data, t) # noise_rel_base_pts_latents, pert_joints_offset_sequence # joints_offset_sequence if self.diff_basejtsrel: joints_denoised_latents = out_dict['joints_denoised_latents'] joints_offset_output = out_dict['joints_offset_output'] # bsz x seq x nnjts x 3 if self.args.train_enc: pred_joints_loss = torch.sum( (joints_offset_sequence - joints_offset_output) ** 2, dim=-1 ).mean(dim=-1).mean(dim=-1) denoising_latents_loss = torch.zeros_like(pred_joints_loss) else: if self.args.pred_diff_noise: denoising_latents_loss = torch.sum( (joints_denoised_latents.permute(1, 0, 2) - noise_rel_base_pts_latents) ** 2, dim=-1 ).mean(dim=-1) else: denoising_latents_loss = torch.sum( # denoising latents (joints_denoised_latents.permute(1, 0, 2) - rel_base_pts_latents.permute(1, 0, 2).detach()) ** 2, dim=-1 ).mean(dim=-1) joints_offset_output = model.model.dec_basejtsrel_only_fr_latents( rel_base_pts_latents, input_data)['basejtsrel_output'] pred_joints_loss = torch.sum( (joints_offset_sequence - joints_offset_output) ** 2, dim=-1 ).mean(dim=-1).mean(dim=-1) terms['jts_pred_loss'] = pred_joints_loss terms['jts_latent_denoising_loss'] = denoising_latents_loss terms['rot_mse'] += pred_joints_loss + denoising_latents_loss if self.args.use_vae and self.args.kl_weights > 0.: terms['KL_loss'] = KL_loss # KL loss ## terms['rot_mse'] += KL_loss * self.args.kl_weights # if self.diff_jts: # dec_clean_joint_seq = dec_out_dict["joint_seq_output"] # dec_clena_seq_latents = dec_out_dict["joints_seq_latents"] # ### joints latents and decoding prediction ### # jts_pred_loss = torch.sum( # (rhand_joints - dec_clean_joint_seq) ** 2, dim=-1 # ).mean(dim=-1).mean(dim=-1) # if self.args.pred_diff_noise: # # noise_joint_seq_latents # jts_latent_denoising_loss = (torch.sum( # (dec_clena_seq_latents.permute(1, 0, 2).contiguous() - noise_joint_seq_latents.permute(1, 0, 2).contiguous()) ** 2, dim=-1 # ) / dec_clena_seq_latents.size(-1)).mean(dim=-1) # else: # jts_latent_denoising_loss = (torch.sum( # (dec_clena_seq_latents.permute(1, 0, 2).contiguous() - clean_joint_seq_latents.permute(1, 0, 2).contiguous()) ** 2, dim=-1 # ) / dec_clena_seq_latents.size(-1)).mean(dim=-1) # if self.args.train_enc: # jts_latent_denoising_loss = torch.zeros_like(jts_latent_denoising_loss) # if self.args.train_diff: # train_diff # jts_pred_loss = torch.zeros_like(jts_pred_loss) # terms['jts_pred_loss'] = jts_pred_loss # terms['jts_latent_denoising_loss'] = jts_latent_denoising_loss # ## train diff or train enc --> enc # # jts_pred_loss_coeff # # basejtsrel_pred_loss_coeff, # # basejtse_along_normal_loss_coeff, basejtse_vt_normal_loss_coeff # # terms['rot_mse'] += jts_pred_loss * 20 + jts_latent_denoising_loss # terms['rot_mse'] += jts_pred_loss * self.args.jts_pred_loss_coeff + jts_latent_denoising_loss # if 
self.diff_basejtsrel: # dec_basejtsrel = dec_out_dict["basejtsrel_output"] # dec_clean_basejtsrel_seq_latents = dec_out_dict["basejtsrel_seq_latents"] # # rel_base_pts_to_rhand_joints: bsz x ws x nnj x nnb x 3 # # ## energeis # # basejtsrel_latent_denoising_loss = (torch.sum( # # (dec_clean_basejtsrel_seq_latents.permute(1, 0, 2).contiguous() - basejtsrel_seq_latents.permute(1, 0, 2).contiguous()) ** 2, dim=-1 # # ) / basejtsrel_seq_latents.size(-1)).mean(dim=-1) ### # dec_basejtsrel_latenets = dec_clean_basejtsrel_seq_latents # if not self.args.train_enc and self.args.pred_diff_noise: ## pre diff noise ## # # noise_joint_seq_latents # basejtsrel_latent_denoising_loss = (torch.sum( # (dec_clean_basejtsrel_seq_latents.permute(1, 0, 2).contiguous() - noise_basejtsrel_seq_latents.permute(1, 0, 2).contiguous()) ** 2, dim=-1 # ) / noise_basejtsrel_seq_latents.size(-1)).mean(dim=-1) # dec_basejtsrel_latenets = self._predict_xstart_from_eps(pert_basejtsrel_seq_latents.permute(1, 0, 2), t=t, eps=noise_basejtsrel_seq_latents.permute(1, 0, 2)).permute(1, 0, 2) # # dec_basejtsrel_latenets = self._predict_xstart_from_eps(pert_basejtsrel_seq_latents.permute(1, 0, 2), t=t, eps=dec_clean_basejtsrel_seq_latents.permute(1, 0, 2)).permute(1, 0, 2) # dec_out_dict.update( # model.model.dec_basejtsrel_only_fr_latents(dec_basejtsrel_latenets, input_data) # ) # dec_basejtsrel = dec_out_dict["basejtsrel_output"] # else: # basejtsrel_latent_denoising_loss = (torch.sum( # (dec_clean_basejtsrel_seq_latents.permute(1, 0, 2).contiguous() - basejtsrel_seq_latents.permute(1, 0, 2).contiguous()) ** 2, dim=-1 # ) / basejtsrel_seq_latents.size(-1)).mean(dim=-1) # basejtsrel_pred_loss = torch.sum( # (rel_base_pts_to_rhand_joints.permute(0, 1, 3, 2, 4).contiguous() - dec_basejtsrel) ** 2, dim=-1 # ).mean(dim=-1).mean(dim=-1).mean(dim=-1) # ### Avg-jts-outputs ### # if 'avg_jts_outputs' in dec_out_dict and not (self.args.not_pred_avg_jts): # dec_avg_jts_sequence = dec_out_dict['avg_jts_outputs'] # gt_avg_joints_sequence = avg_joints_sequence # avg_joints_pred_loss = torch.sum( # (gt_avg_joints_sequence - dec_avg_jts_sequence) ** 2, dim=-1 # ).mean(dim=-1) # else: # avg_joints_pred_loss = None # if self.args.train_enc: # train enc # basejtsrel_latent_denoising_loss = torch.zeros_like(basejtsrel_latent_denoising_loss) # if self.args.train_diff: # train_diff # basejtsrel_pred_loss = torch.zeros_like(basejtsrel_pred_loss) # if avg_joints_pred_loss is not None: # avg_joints_pred_loss = torch.zeros_like(avg_joints_pred_loss) # terms['basejtsrel_pred_loss'] = basejtsrel_pred_loss # terms['basejtsrel_latent_denoising_loss'] = basejtsrel_latent_denoising_loss # if avg_joints_pred_loss is not None: # terms['avg_joints_pred_loss'] = avg_joints_pred_loss # # terms['rot_mse'] += basejtsrel_pred_loss * 20 + basejtsrel_latent_denoising_loss # terms['rot_mse'] += basejtsrel_pred_loss * self.args.basejtsrel_pred_loss_coeff + basejtsrel_latent_denoising_loss # if avg_joints_pred_loss is not None: # terms['rot_mse'] += avg_joints_pred_loss # if self.diff_basejtse: # dec_base_jts_e_feats = dec_out_dict['base_jts_e_feats'] # dec_e_along_normals = dec_out_dict['dec_e_along_normals'] # dec_e_vt_normals = dec_out_dict['dec_e_vt_normals'] # # e_disp_rel_to_base_along_normals, e_disp_rel_to_baes_vt_normals # # rel_base_pts_to_rhand_joints: bsz x ws x nnj x nnb x 3 # # basejtse_along_normals_pred_loss = torch.sum( # (e_disp_rel_to_base_along_normals.unsqueeze(-1) - dec_e_along_normals.unsqueeze(-1)) ** 2, dim=-1 # 
).mean(dim=-1).mean(dim=-1).mean(dim=-1) # basejtse_vt_normals_pred_loss = torch.sum( # (e_disp_rel_to_baes_vt_normals.unsqueeze(-1) - dec_e_vt_normals.unsqueeze(-1)) ** 2, dim=-1 # ).mean(dim=-1).mean(dim=-1).mean(dim=-1) # # basejtse_latent_denoising_loss = (torch.sum( # # (basejtse_seq_latents.permute(1, 0, 2).contiguous() - dec_base_jts_e_feats.permute(1, 0, 2).contiguous()) ** 2, dim=-1 # # ) / basejtse_seq_latents.size(-1)).mean(dim=-1) # if self.args.pred_diff_noise: # # noise_joint_seq_latents # basejtse_latent_denoising_loss = (torch.sum( # (basejtse_seq_latents.permute(1, 0, 2).contiguous() - noise_basejtse_seq_latents.permute(1, 0, 2).contiguous()) ** 2, dim=-1 # ) / basejtse_seq_latents.size(-1)).mean(dim=-1) # else: # basejtse_latent_denoising_loss = (torch.sum( # (basejtse_seq_latents.permute(1, 0, 2).contiguous() - dec_base_jts_e_feats.permute(1, 0, 2).contiguous()) ** 2, dim=-1 # ) / basejtse_seq_latents.size(-1)).mean(dim=-1) # # find out kong # # if self.args.train_enc: # basejtse_latent_denoising_loss = torch.zeros_like(basejtse_latent_denoising_loss) # if self.args.train_diff: # train_diff # no basejtse denoising losses ## # # basejtse_latent_denoising_loss = torch.zeros_like(basejtse_latent_denoising_loss) # basejtse_along_normals_pred_loss = torch.zeros_like(basejtse_along_normals_pred_loss) # basejtse_vt_normals_pred_loss = torch.zeros_like(basejtse_vt_normals_pred_loss) # terms['basejtse_along_normals_pred_loss'] = basejtse_along_normals_pred_loss # terms['basejtse_vt_normals_pred_loss'] = basejtse_vt_normals_pred_loss # terms['basejtse_latent_denoising_loss'] = basejtse_latent_denoising_loss # # terms['rot_mse'] += basejtse_along_normals_pred_loss * 20 + basejtse_vt_normals_pred_loss * 20 + basejtse_latent_denoising_loss # terms['rot_mse'] += basejtse_along_normals_pred_loss * self.args.basejtse_along_normal_loss_coeff + basejtse_vt_normals_pred_loss * self.args.basejtse_vt_normal_loss_coeff + basejtse_latent_denoising_loss # if self.args.kl_weights > 0. 
and not self.args.train_diff: # terms['KL_loss'] = KL_loss # terms['rot_mse'] += KL_loss * self.args.kl_weights # sv_inter_dict = { # 'dec_joints': dec_clean_joint_seq.detach().cpu().numpy(), # 'rhand_joints': rhand_joints.detach().cpu().numpy(), # 't': t.detach().cpu().numpy(), # } # # sv_inter_dict_fn = os.path.join(args.save_dir, ) ### construct final loss ### # terms['rot_mse'] = jts_pred_loss * 20 + jts_latent_denoising_loss + basejtsrel_pred_loss * 20 + basejtsrel_latent_denoising_loss + basejtse_along_normals_pred_loss * 20 + basejtse_vt_normals_pred_loss * 20 + basejtse_latent_denoising_loss # ### === only use joints-only losses === ### # terms['rot_mse'] = jts_pred_loss * 20 + jts_latent_denoising_loss ### === only use base-jts-rel losses === ### # terms['rot_mse'] = basejtsrel_pred_loss * 20 + basejtsrel_latent_denoising_loss ### === only use base-jts-e losses === ### # terms['rot_mse'] = basejtse_along_normals_pred_loss * 20 + basejtse_vt_normals_pred_loss * 20 + basejtse_latent_denoising_loss # terms["rot_mse"] = self.masked_l2(target, model_output, mask) # mean_flat(rot_mse) # sv_out_in = { # 'model_output': model_output.detach().cpu().numpy(), # 'target': target.detach().cpu().numpy(), # 't': t.detach().cpu().numpy(), # } import os import datetime cur_time_stamp = datetime.datetime.now().timestamp() cur_time_stamp = str(cur_time_stamp) # sv_dir_rt = "/data1/sim/motion-diffusion-model" # sv_out_fn = os.path.join(sv_dir_rt, f"out_{cur_time_stamp}.npy") # np.save(sv_out_fn,sv_inter_dict ) # print(f"Samples saved to {sv_out_fn}") target_xyz, model_output_xyz = None, None if self.lambda_rcxyz > 0. and dataset.dataname not in ['motion_ours']: print(f"Calculating lambda_rcxyz!!!") target_xyz = get_xyz(target) # [bs, nvertices(vertices)/njoints(smpl), 3, nframes] model_output_xyz = get_xyz(model_output) # [bs, nvertices, 3, nframes] terms["rcxyz_mse"] = self.masked_l2(target_xyz, model_output_xyz, mask) # mean_flat((target_xyz - model_output_xyz) ** 2) if self.lambda_vel_rcxyz > 0.: if self.data_rep == 'rot6d' and dataset.dataname in ['humanact12', 'uestc']: target_xyz = get_xyz(target) if target_xyz is None else target_xyz model_output_xyz = get_xyz(model_output) if model_output_xyz is None else model_output_xyz target_xyz_vel = (target_xyz[:, :, :, 1:] - target_xyz[:, :, :, :-1]) model_output_xyz_vel = (model_output_xyz[:, :, :, 1:] - model_output_xyz[:, :, :, :-1]) terms["vel_xyz_mse"] = self.masked_l2(target_xyz_vel, model_output_xyz_vel, mask[:, :, :, 1:]) if self.lambda_fc > 0.: ## lambda fc ## torch.autograd.set_detect_anomaly(True) if self.data_rep == 'rot6d' and dataset.dataname in ['humanact12', 'uestc']: target_xyz = get_xyz(target) if target_xyz is None else target_xyz model_output_xyz = get_xyz(model_output) if model_output_xyz is None else model_output_xyz # 'L_Ankle', # 7, 'R_Ankle', # 8 , 'L_Foot', # 10, 'R_Foot', # 11 l_ankle_idx, r_ankle_idx, l_foot_idx, r_foot_idx = 7, 8, 10, 11 relevant_joints = [l_ankle_idx, l_foot_idx, r_ankle_idx, r_foot_idx] gt_joint_xyz = target_xyz[:, relevant_joints, :, :] # [BatchSize, 4, 3, Frames] gt_joint_vel = torch.linalg.norm(gt_joint_xyz[:, :, :, 1:] - gt_joint_xyz[:, :, :, :-1], axis=2) # [BatchSize, 4, Frames] fc_mask = torch.unsqueeze((gt_joint_vel <= 0.01), dim=2).repeat(1, 1, 3, 1) pred_joint_xyz = model_output_xyz[:, relevant_joints, :, :] # [BatchSize, 4, 3, Frames] pred_vel = pred_joint_xyz[:, :, :, 1:] - pred_joint_xyz[:, :, :, :-1] pred_vel[~fc_mask] = 0 terms["fc"] = self.masked_l2(pred_vel, torch.zeros(pred_vel.shape, 
device=pred_vel.device), mask[:, :, :, 1:]) self.lambda_vel = 0. if self.lambda_vel > 0.: target_vel = (target[..., 1:] - target[..., :-1]) model_output_vel = (model_output[..., 1:] - model_output[..., :-1]) terms["vel_mse"] = self.masked_l2(target_vel[:, :-1, :, :], # Remove last joint, is the root location! model_output_vel[:, :-1, :, :], mask[:, :, :, 1:]) # mean_flat((target_vel - model_output_vel) ** 2) terms["loss"] = terms["rot_mse"] + terms.get('vb', 0.) +\ (self.lambda_vel * terms.get('vel_mse', 0.)) +\ (self.lambda_rcxyz * terms.get('rcxyz_mse', 0.)) + \ (self.lambda_fc * terms.get('fc', 0.)) # else: # raise NotImplementedError(self.loss_type) return terms ## training losses def predict_sample_single_step(self, model, x_start, t, model_kwargs=None, noise=None, dataset=None): """ # s ## predict sa Compute training losses for a single timestep. :param model: the model to evaluate loss on. :param x_start: the [N x C x ...] tensor of inputs. :param t: a batch of timestep indices. :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used for conditioning. :param noise: if specified, the specific Gaussian noise to try to remove. :return: a dict with the key "loss" containing a tensor of shape [N]. Some mean or variance settings may also have other keys. """ # enc = model.model._modules['module'] # enc = model.model mask = model_kwargs['y']['mask'] ## rot2xyz; # # ### avg_joints, std_joints ### # if 'avg_joints' in model_kwargs['y']: avg_joints = model_kwargs['y']['avg_joints'].unsqueeze(-1) std_joints = model_kwargs['y']['std_joints'].unsqueeze(-1).unsqueeze(-1) else: avg_joints = None std_joints = None # ### avg_joints, std_joints ### # # get_xyz = lambda sample: enc.rot2xyz(sample, mask=None, pose_rep=enc.pose_rep, translation=enc.translation, # glob=enc.glob, ## rot2xyz; ## # # jointstype='vertices', # 3.4 iter/sec # USED ALSO IN MotionCLIP # jointstype='smpl', # 3.4 iter/sec # vertstrans=False) if model_kwargs is None: ## predict single steps --> model_kwargs = {} if noise is None: noise = th.randn_like(x_start) x_t = self.q_sample(x_start, t, noise=noise) # randn_like for x_start, t, x_t --> get x_t from x_start # # how we control the tiem stamp t? terms = {} if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: terms["loss"] = self._vb_terms_bpd( ## vb terms bpd # model=model, x_start=x_start, x_t=x_t, t=t, clip_denoised=False, model_kwargs=model_kwargs, )["output"] if self.loss_type == LossType.RESCALED_KL: terms["loss"] *= self.num_timesteps elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: model_output = model(x_t, self._scale_timesteps(t), **model_kwargs) # model_output ---> model x_t # if self.model_var_type in [ ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE, ]: # s B, C = x_t.shape[:2] assert model_output.shape == (B, C * 2, *x_t.shape[2:]) model_output, model_var_values = th.split(model_output, C, dim=1) # Learn the variance using the variational bound, but don't let # it affect our mean prediction. frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) terms["vb"] = self._vb_terms_bpd( model=lambda *args, r=frozen_out: r, x_start=x_start, x_t=x_t, t=t, clip_denoised=False, )["output"] if self.loss_type == LossType.RESCALED_MSE: # Divide by 1000 for equivalence with initial implementation. # Without a factor of 1/1000, the VB term hurts the MSE term. 
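            # For reference: the regression target selected just below depends on
            # model_mean_type -- PREVIOUS_X regresses the posterior mean of
            # q(x_{t-1} | x_t, x_0), START_X regresses x_0 directly, and EPSILON
            # regresses the injected noise.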
terms["vb"] *= self.num_timesteps / 1000.0 target = { # q posterior mean variance # ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( x_start=x_start, x_t=x_t, t=t )[0], ModelMeanType.START_X: x_start, ModelMeanType.EPSILON: noise, }[self.model_mean_type] # model mean type --> mean type # assert model_output.shape == target.shape == x_start.shape # [bs, njoints, nfeats, nframes] terms["rot_mse"] = self.masked_l2(target, model_output, mask) # mean_flat(rot_mse) # model_output, target, t # ### avg_joints, std_joints ### # if avg_joints is not None: print(f"model_output: {model_output.size()}, target: {target.size()}, std_joints: {std_joints.size()}, avg_joints: {avg_joints.size()}") print(f"Denormalizing joints...") model_output = (model_output * std_joints) + avg_joints target = (target * std_joints) + avg_joints sv_out_in = { # 'model_output': model_output.detach().cpu().numpy(), 'target': target.detach().cpu().numpy(), 't': t.detach().cpu().numpy(), } import os import datetime cur_time_stamp = datetime.datetime.now().timestamp() cur_time_stamp = str(cur_time_stamp) sv_dir_rt = "/data1/sim/motion-diffusion-model" sv_out_fn = os.path.join(sv_dir_rt, f"out_{cur_time_stamp}.npy") np.save(sv_out_fn,sv_out_in ) print(f"Samples saved to {sv_out_fn}") target_xyz, model_output_xyz = None, None self.lambda_rcxyz = 0. if self.lambda_rcxyz > 0.: target_xyz = get_xyz(target) # [bs, nvertices(vertices)/njoints(smpl), 3, nframes] model_output_xyz = get_xyz(model_output) # [bs, nvertices, 3, nframes] terms["rcxyz_mse"] = self.masked_l2(target_xyz, model_output_xyz, mask) # mean_flat((target_xyz - model_output_xyz) ** 2) if self.lambda_vel_rcxyz > 0.: if self.data_rep == 'rot6d' and dataset.dataname in ['humanact12', 'uestc']: target_xyz = get_xyz(target) if target_xyz is None else target_xyz model_output_xyz = get_xyz(model_output) if model_output_xyz is None else model_output_xyz target_xyz_vel = (target_xyz[:, :, :, 1:] - target_xyz[:, :, :, :-1]) model_output_xyz_vel = (model_output_xyz[:, :, :, 1:] - model_output_xyz[:, :, :, :-1]) terms["vel_xyz_mse"] = self.masked_l2(target_xyz_vel, model_output_xyz_vel, mask[:, :, :, 1:]) if self.lambda_fc > 0.: ## lambda fc ## torch.autograd.set_detect_anomaly(True) if self.data_rep == 'rot6d' and dataset.dataname in ['humanact12', 'uestc']: target_xyz = get_xyz(target) if target_xyz is None else target_xyz model_output_xyz = get_xyz(model_output) if model_output_xyz is None else model_output_xyz # 'L_Ankle', # 7, 'R_Ankle', # 8 , 'L_Foot', # 10, 'R_Foot', # 11 l_ankle_idx, r_ankle_idx, l_foot_idx, r_foot_idx = 7, 8, 10, 11 relevant_joints = [l_ankle_idx, l_foot_idx, r_ankle_idx, r_foot_idx] gt_joint_xyz = target_xyz[:, relevant_joints, :, :] # [BatchSize, 4, 3, Frames] gt_joint_vel = torch.linalg.norm(gt_joint_xyz[:, :, :, 1:] - gt_joint_xyz[:, :, :, :-1], axis=2) # [BatchSize, 4, Frames] fc_mask = torch.unsqueeze((gt_joint_vel <= 0.01), dim=2).repeat(1, 1, 3, 1) pred_joint_xyz = model_output_xyz[:, relevant_joints, :, :] # [BatchSize, 4, 3, Frames] pred_vel = pred_joint_xyz[:, :, :, 1:] - pred_joint_xyz[:, :, :, :-1] pred_vel[~fc_mask] = 0 terms["fc"] = self.masked_l2(pred_vel, torch.zeros(pred_vel.shape, device=pred_vel.device), mask[:, :, :, 1:]) if self.lambda_vel > 0.: target_vel = (target[..., 1:] - target[..., :-1]) model_output_vel = (model_output[..., 1:] - model_output[..., :-1]) terms["vel_mse"] = self.masked_l2(target_vel[:, :-1, :, :], # Remove last joint, is the root location! 
    def fc_loss_rot_repr(self, gt_xyz, pred_xyz, mask):
        """
        gt_xyz / pred_xyz: SMPL batch tensors of shape [BatchSize, 24, 3, Frames].
        """
        def to_np_cpu(x):
            return x.detach().cpu().numpy()

        # 'L_Ankle' -> 7, 'R_Ankle' -> 8, 'L_Foot' -> 10, 'R_Foot' -> 11
        l_ankle_idx, r_ankle_idx = 7, 8
        l_foot_idx, r_foot_idx = 10, 11

        """ Contact calculated by 'Kfir Method' (commented code) """
        # contact_signal = torch.zeros((pose_xyz.shape[0], pose_xyz.shape[3], 2), device=pose_xyz.device)  # [BatchSize, Frames, 2]
        # left_xyz = 0.5 * (pose_xyz[:, l_ankle_idx, :, :] + pose_xyz[:, l_foot_idx, :, :])  # [BatchSize, 3, Frames]
        # right_xyz = 0.5 * (pose_xyz[:, r_ankle_idx, :, :] + pose_xyz[:, r_foot_idx, :, :])
        # left_z, right_z = left_xyz[:, 2, :], right_xyz[:, 2, :]  # [BatchSize, Frames]
        # left_velocity = torch.linalg.norm(left_xyz[:, :, 2:] - left_xyz[:, :, :-2], axis=1)  # [BatchSize, Frames]
        # right_velocity = torch.linalg.norm(right_xyz[:, :, 2:] - right_xyz[:, :, :-2], axis=1)  # (fixed: was computed from left_xyz)
        #
        # left_z_mask = left_z <= torch.mean(torch.sort(left_z)[0][:, :left_z.shape[1] // 5], axis=-1)
        # left_z_mask = torch.stack([left_z_mask, left_z_mask], dim=-1)  # [BatchSize, Frames, 2]
        # left_z_mask[:, :, 1] = False  # Blank right side
        # contact_signal[left_z_mask] = 0.4
        #
        # right_z_mask = right_z <= torch.mean(torch.sort(right_z)[0][:, :right_z.shape[1] // 5], axis=-1)
        # right_z_mask = torch.stack([right_z_mask, right_z_mask], dim=-1)  # [BatchSize, Frames, 2]
        # right_z_mask[:, :, 0] = False  # Blank left side
        # contact_signal[right_z_mask] = 0.4
        #
        # contact_signal[left_z <= (torch.mean(torch.sort(left_z)[:left_z.shape[0] // 5]) + 20), 0] = 1
        # contact_signal[right_z <= (torch.mean(torch.sort(right_z)[:right_z.shape[0] // 5]) + 20), 1] = 1
        # plt.plot(to_np_cpu(left_z[0]), label='left_z')
        # plt.plot(to_np_cpu(left_velocity[0]), label='left_velocity')
        # plt.plot(to_np_cpu(contact_signal[0, :, 0]), label='left_fc')
        # plt.grid(); plt.legend(); plt.show()
        # plt.plot(to_np_cpu(right_z[0]), label='right_z')
        # plt.plot(to_np_cpu(right_velocity[0]), label='right_velocity')
        # plt.plot(to_np_cpu(contact_signal[0, :, 1]), label='right_fc')
        # plt.grid(); plt.legend(); plt.show()

        gt_joint_xyz = gt_xyz[:, [l_ankle_idx, l_foot_idx, r_ankle_idx, r_foot_idx], :, :]  # [BatchSize, 4, 3, Frames]
        gt_joint_vel = torch.linalg.norm(gt_joint_xyz[:, :, :, 1:] - gt_joint_xyz[:, :, :, :-1], axis=2)  # [BatchSize, 4, Frames]
        fc_mask = (gt_joint_vel <= 0.01)
        pred_joint_xyz = pred_xyz[:, [l_ankle_idx, l_foot_idx, r_ankle_idx, r_foot_idx], :, :]  # [BatchSize, 4, 3, Frames]
        pred_joint_vel = torch.linalg.norm(pred_joint_xyz[:, :, :, 1:] - pred_joint_xyz[:, :, :, :-1], axis=2)  # [BatchSize, 4, Frames]
        pred_joint_vel[~fc_mask] = 0  # Blank non-contact velocity frames; [BS, 4, Frames]
        pred_joint_vel = torch.unsqueeze(pred_joint_vel, dim=2)

        """DEBUG CODE"""
        # print(f'mask: {mask.shape}')
        # print(f'pred_joint_vel: {pred_joint_vel.shape}')
        # plt.title(f'Joint: {joint_idx}')
        # plt.plot(to_np_cpu(gt_joint_vel[0]), label='velocity')
        # plt.plot(to_np_cpu(fc_mask[0]), label='fc')
        # plt.grid(); plt.legend(); plt.show()
        return self.masked_l2(pred_joint_vel,
                              torch.zeros(pred_joint_vel.shape, device=pred_joint_vel.device),
                              mask[:, :, :, 1:])
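    # Shape sketch for fc_loss_rot_repr (toy sizes; assumes the 24-joint SMPL layout
    # used above): gt_xyz and pred_xyz are [bs, 24, 3, nframes], mask is typically
    # [bs, 1, 1, nframes], e.g.
    #
    #   fc = self.fc_loss_rot_repr(gt_xyz, pred_xyz, mask)
    #
    # Only the ankle/foot joints (7, 10, 8, 11) contribute, and only on frames where
    # the ground-truth joint moves by at most 0.01 between consecutive frames.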
    # TODO - NOT USED YET, JUST COMMITTING TO NOT DELETE THIS AND KEEP INITIAL IMPLEMENTATION, NOT DONE!
    def foot_contact_loss_humanml3d(self, target, model_output):
        # HumanML3D feature layout (per frame):
        # root_rot_velocity    (B, seq_len, 1)
        # root_linear_velocity (B, seq_len, 2)
        # root_y               (B, seq_len, 1)
        # ric_data             (B, seq_len, (joint_num - 1)*3), XYZ
        # rot_data             (B, seq_len, (joint_num - 1)*6), 6D
        # local_velocity       (B, seq_len, joint_num*3), XYZ
        # foot contact         (B, seq_len, 4)
        target_fc = target[:, -4:, :, :]
        root_rot_velocity = target[:, :1, :, :]
        root_linear_velocity = target[:, 1:3, :, :]
        root_y = target[:, 3:4, :, :]
        ric_data = target[:, 4:67, :, :]  # 4 + (3*21) = 67
        rot_data = target[:, 67:193, :, :]  # 67 + (6*21) = 193
        local_velocity = target[:, 193:259, :, :]  # 193 + (3*22) = 259
        contact = target[:, 259:, :, :]  # 259 + 4 = 263
        contact_mask_gt = contact > 0.5  # contact mask order: fid_l [7, 10], fid_r [8, 11]
        vel_lf_7 = local_velocity[:, 7 * 3:8 * 3, :, :]
        vel_rf_8 = local_velocity[:, 8 * 3:9 * 3, :, :]
        vel_lf_10 = local_velocity[:, 10 * 3:11 * 3, :, :]
        vel_rf_11 = local_velocity[:, 11 * 3:12 * 3, :, :]

        calc_vel_lf_7 = ric_data[:, 6 * 3:7 * 3, :, 1:] - ric_data[:, 6 * 3:7 * 3, :, :-1]
        calc_vel_rf_8 = ric_data[:, 7 * 3:8 * 3, :, 1:] - ric_data[:, 7 * 3:8 * 3, :, :-1]
        calc_vel_lf_10 = ric_data[:, 9 * 3:10 * 3, :, 1:] - ric_data[:, 9 * 3:10 * 3, :, :-1]
        calc_vel_rf_11 = ric_data[:, 10 * 3:11 * 3, :, 1:] - ric_data[:, 10 * 3:11 * 3, :, :-1]

        # vel_foots = torch.stack([vel_lf_7, vel_lf_10, vel_rf_8, vel_rf_11], dim=1)
        # NOTE: the calculated-velocity list is ordered [7, 10, 8, 11] so that it lines
        # up with the vector-velocity list and the joint/contact indices below.
        for chosen_vel_foot_calc, chosen_vel_foot, joint_idx, contact_mask_idx in zip(
                [calc_vel_lf_7, calc_vel_lf_10, calc_vel_rf_8, calc_vel_rf_11],
                [vel_lf_7, vel_lf_10, vel_rf_8, vel_rf_11],
                [7, 10, 8, 11],
                [0, 1, 2, 3]):
            tmp_mask_gt = contact_mask_gt[:, contact_mask_idx, :, :].cpu().detach().numpy().reshape(-1).astype(int)
            chosen_vel_norm = np.linalg.norm(chosen_vel_foot.cpu().detach().numpy().reshape((3, -1)), axis=0)
            chosen_vel_calc_norm = np.linalg.norm(chosen_vel_foot_calc.cpu().detach().numpy().reshape((3, -1)), axis=0)

            print(tmp_mask_gt.shape)
            print(chosen_vel_foot.shape)
            print(chosen_vel_calc_norm.shape)

            import matplotlib.pyplot as plt
            plt.plot(tmp_mask_gt, label='FC mask')
            plt.plot(chosen_vel_norm, label='Vel. XYZ norm (from vector)')
            plt.plot(chosen_vel_calc_norm, label='Vel. XYZ norm (calculated diff XYZ)')
            plt.title(f'FC idx {contact_mask_idx}, Joint Index {joint_idx}')
            plt.legend()
            plt.show()
        # print(vel_foots.shape)
        return 0
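    # A hypothetical convenience splitter (a sketch, not used elsewhere in this file)
    # for the 263-channel HumanML3D representation that the two debug losses above and
    # below slice by hand; the offsets mirror their inline comments (22 joints).
    @staticmethod
    def _split_humanml3d_feats(x):
        """x: [bs, 263, 1, nframes] -> dict of named feature slices."""
        return {
            'root_rot_velocity': x[:, 0:1],     # 1 channel
            'root_linear_velocity': x[:, 1:3],  # 2 channels
            'root_y': x[:, 3:4],                # 1 channel
            'ric_data': x[:, 4:67],             # 21 joints * 3 (XYZ)
            'rot_data': x[:, 67:193],           # 21 joints * 6 (6D rotations)
            'local_velocity': x[:, 193:259],    # 22 joints * 3
            'foot_contact': x[:, 259:263],      # 4 binary contact signals
        }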
    # TODO - NOT USED YET, JUST COMMITTING TO NOT DELETE THIS AND KEEP INITIAL IMPLEMENTATION, NOT DONE!
    def velocity_consistency_loss_humanml3d(self, target, model_output):
        # HumanML3D feature layout (per frame):
        # root_rot_velocity    (B, seq_len, 1)
        # root_linear_velocity (B, seq_len, 2)
        # root_y               (B, seq_len, 1)
        # ric_data             (B, seq_len, (joint_num - 1)*3), XYZ
        # rot_data             (B, seq_len, (joint_num - 1)*6), 6D
        # local_velocity       (B, seq_len, joint_num*3), XYZ
        # foot contact         (B, seq_len, 4)
        target_fc = target[:, -4:, :, :]
        root_rot_velocity = target[:, :1, :, :]
        root_linear_velocity = target[:, 1:3, :, :]
        root_y = target[:, 3:4, :, :]
        ric_data = target[:, 4:67, :, :]  # 4 + (3*21) = 67
        rot_data = target[:, 67:193, :, :]  # 67 + (6*21) = 193
        local_velocity = target[:, 193:259, :, :]  # 193 + (3*22) = 259
        contact = target[:, 259:, :, :]  # 259 + 4 = 263

        calc_vel_from_xyz = ric_data[:, :, :, 1:] - ric_data[:, :, :, :-1]
        velocity_from_vector = local_velocity[:, 3:, :, 1:]  # Slicing out root
        r_rot_quat, r_pos = motion_process.recover_root_rot_pos(target.permute(0, 2, 3, 1).type(th.FloatTensor))
        print(f'r_rot_quat: {r_rot_quat.shape}')
        print(f'calc_vel_from_xyz: {calc_vel_from_xyz.shape}')
        calc_vel_from_xyz = calc_vel_from_xyz.permute(0, 2, 3, 1)
        calc_vel_from_xyz = calc_vel_from_xyz.reshape((1, 1, -1, 21, 3)).type(th.FloatTensor)
        r_rot_quat_adapted = r_rot_quat[..., :-1, None, :].repeat((1, 1, 1, 21, 1)).to(calc_vel_from_xyz.device)
        print(f'calc_vel_from_xyz: {calc_vel_from_xyz.shape}, {calc_vel_from_xyz.device}')
        print(f'r_rot_quat_adapted: {r_rot_quat_adapted.shape}, {r_rot_quat_adapted.device}')

        calc_vel_from_xyz = motion_process.qrot(r_rot_quat_adapted, calc_vel_from_xyz)
        calc_vel_from_xyz = calc_vel_from_xyz.reshape((1, 1, -1, 21 * 3))
        calc_vel_from_xyz = calc_vel_from_xyz.permute(0, 3, 1, 2)
        print(f'calc_vel_from_xyz: {calc_vel_from_xyz.shape}, {calc_vel_from_xyz.device}')

        import matplotlib.pyplot as plt
        for i in range(21):
            plt.plot(np.linalg.norm(calc_vel_from_xyz[:, i * 3:(i + 1) * 3, :, :].cpu().detach().numpy().reshape((3, -1)), axis=0), label='Calc Vel')
            plt.plot(np.linalg.norm(velocity_from_vector[:, i * 3:(i + 1) * 3, :, :].cpu().detach().numpy().reshape((3, -1)), axis=0), label='Vector Vel')
            plt.title(f'Joint idx: {i}')
            plt.legend()
            plt.show()
        print(calc_vel_from_xyz.shape)
        print(velocity_from_vector.shape)
        diff = calc_vel_from_xyz - velocity_from_vector
        print(np.linalg.norm(diff.cpu().detach().numpy().reshape((63, -1)), axis=0))
        return 0

    def _prior_bpd(self, x_start):
        """
        Get the prior KL term for the variational lower-bound, measured in
        bits-per-dim.

        This term can't be optimized, as it only depends on the encoder.

        :param x_start: the [N x C x ...] tensor of inputs.
        :return: a batch of [N] KL values (in bits), one per batch element.
        """
        batch_size = x_start.shape[0]
        t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
        qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
        kl_prior = normal_kl(
            mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
        )
        return mean_flat(kl_prior) / np.log(2.0)
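    # Note on _prior_bpd above: it evaluates KL(q(x_T | x_0) || N(0, I)) / ln(2),
    # i.e. how far the forward process has drifted from the standard-normal prior at
    # the final timestep, in bits per dimension. For a well-chosen beta schedule this
    # term is near zero; calc_bpd_loop below adds it to the per-timestep VLB terms.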
    def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
        """
        Compute the entire variational lower-bound, measured in bits-per-dim,
        as well as other related quantities.

        :param model: the model to evaluate loss on.
        :param x_start: the [N x C x ...] tensor of inputs.
        :param clip_denoised: if True, clip denoised samples.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.

        :return: a dict containing the following keys:
                 - total_bpd: the total variational lower-bound, per batch element.
                 - prior_bpd: the prior term in the lower-bound.
                 - vb: an [N x T] tensor of terms in the lower-bound.
                 - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
                 - mse: an [N x T] tensor of epsilon MSEs for each timestep.
        """
        device = x_start.device
        batch_size = x_start.shape[0]

        vb = []
        xstart_mse = []
        mse = []
        for t in list(range(self.num_timesteps))[::-1]:
            t_batch = th.tensor([t] * batch_size, device=device)
            noise = th.randn_like(x_start)
            x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
            # Calculate VLB term at the current timestep
            with th.no_grad():
                out = self._vb_terms_bpd(
                    model,
                    x_start=x_start,
                    x_t=x_t,
                    t=t_batch,
                    clip_denoised=clip_denoised,
                    model_kwargs=model_kwargs,
                )
            vb.append(out["output"])
            xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
            eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
            mse.append(mean_flat((eps - noise) ** 2))

        vb = th.stack(vb, dim=1)
        xstart_mse = th.stack(xstart_mse, dim=1)
        mse = th.stack(mse, dim=1)

        prior_bpd = self._prior_bpd(x_start)
        total_bpd = vb.sum(dim=1) + prior_bpd
        return {
            "total_bpd": total_bpd,
            "prior_bpd": prior_bpd,
            "vb": vb,
            "xstart_mse": xstart_mse,
            "mse": mse,
        }


def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """
    Extract values from a 1-D numpy array for a batch of indices.

    :param arr: the 1-D numpy array.
    :param timesteps: a tensor of indices into the array to extract.
    :param broadcast_shape: a larger shape of K dimensions with the batch
        dimension equal to the length of timesteps.
    :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
    """
    res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res.expand(broadcast_shape)
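# A minimal, self-contained sketch (not part of the original pipeline) of how
# `_extract_into_tensor` broadcasts one per-timestep scalar over a batch of
# activations; the shapes below are toy values chosen purely for illustration.
if __name__ == "__main__":
    _betas = get_named_beta_schedule("linear", 1000)   # 1-D numpy array, length 1000
    _t = th.tensor([0, 499, 999])                      # one timestep index per batch element
    _b = _extract_into_tensor(_betas, _t, (3, 2, 4))   # beta_t broadcast to [3, 2, 4]
    assert _b.shape == (3, 2, 4)
    print(_b[:, 0, 0])  # tensor([beta_0, beta_499, beta_999])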