import torch as th

from .generic_sampler import batch_mul


def get_x0_grad_pred_fn(raw_net_model, cond_loss_fn, weight_fn, x0_update, thres_t):
    """Wrap ``raw_net_model`` into a gradient-corrected x0 prediction function.

    The returned callable predicts x0 from a noisy sample ``xt`` at time
    ``scalar_t``.  For ``scalar_t >= thres_t`` the prediction is corrected by
    stepping against the gradient of ``cond_loss_fn`` with respect to ``xt``,
    scaled per-sample by ``weight_fn``; below the threshold the raw prediction
    is used unchanged.

    Args:
        raw_net_model: callable ``(xt, scalar_t) -> x0`` prediction.
        cond_loss_fn: conditional loss evaluated on an x0 tensor.
        weight_fn: callable ``(x0_pred, grad, cond_loss_fn) -> weights`` giving
            the per-sample step size for the gradient correction.
        x0_update: optional callable ``(x0, scalar_t) -> x0`` post-processing
            hook; any falsy value disables it.
        thres_t: time threshold below which no gradient correction is applied.

    Returns:
        ``fn(xt, scalar_t) -> (x0_cor, loss_info, traj_info)``, where
        ``loss_info`` and ``traj_info`` are dicts of CPU-resident diagnostics.

    NOTE(review): ``fn`` returns the corrected-but-not-``x0_update``-ed tensor;
    the ``x0_update`` result is only recorded in the diagnostics.  TODO confirm
    this is intentional rather than a dropped return value.
    """

    def fn(xt, scalar_t):
        # In-place flag flip: the caller's ``xt`` ends up with
        # requires_grad=True after this call.
        xt = xt.requires_grad_(True)
        pred = raw_net_model(xt, scalar_t)

        loss_info = {"raw_x0": cond_loss_fn(pred.detach()).cpu()}
        traj_info = {"t": scalar_t}

        if scalar_t < thres_t:
            # Below the threshold the raw prediction is trusted as-is.
            corrected = pred.detach()
        else:
            # Differentiate the conditional loss w.r.t. the noisy input and
            # take a weighted step against that gradient.
            loss = cond_loss_fn(pred)
            grad = th.autograd.grad(loss.sum(), xt)[0]
            step = weight_fn(pred, grad, cond_loss_fn)
            corrected = (pred - batch_mul(step, grad)).detach()
            loss_info["weight"] = step.detach().cpu()
            traj_info["grad"] = grad.detach().cpu()

        updated = x0_update(corrected, scalar_t) if x0_update else corrected

        loss_info["cor_x0"] = cond_loss_fn(corrected.detach()).cpu()
        loss_info["x0"] = cond_loss_fn(updated.detach()).cpu()
        traj_info["raw_x0"] = pred.detach().cpu()
        traj_info["cor_x0"] = corrected.detach().cpu()
        traj_info["x0"] = updated.detach().cpu()

        # NOTE(review): returns ``corrected`` (pre-``x0_update``); ``updated``
        # is only logged above — confirm against callers.
        return corrected, loss_info, traj_info

    return fn