Commit dcaa3ad by Ata Celen: "Model Weights added"
diff --git a/config/locomotion.py b/config/locomotion.py
deleted file mode 100644
index 4410bb1..0000000
--- a/config/locomotion.py
+++ /dev/null
@@ -1,70 +0,0 @@
-import socket
-
-from diffuser.utils import watch
-
-#------------------------ base ------------------------#
-
-## automatically make experiment names for planning
-## by labelling folders with these args
-
-diffusion_args_to_watch = [
- ('prefix', ''),
- ('horizon', 'H'),
- ('n_diffusion_steps', 'T'),
-]
-
-base = {
- 'diffusion': {
- ## model
- 'model': 'models.TemporalUnet',
- 'diffusion': 'models.GaussianDiffusion',
- 'horizon': 32,
- 'n_diffusion_steps': 100,
- 'action_weight': 10,
- 'loss_weights': None,
- 'loss_discount': 1,
- 'predict_epsilon': False,
- 'dim_mults': (1, 4, 8),
- 'renderer': 'utils.MuJoCoRenderer',
-
- ## dataset
- 'loader': 'datasets.SequenceDataset',
- 'normalizer': 'LimitsNormalizer',
- 'preprocess_fns': [],
- 'clip_denoised': True,
- 'use_padding': True,
- 'max_path_length': 1000,
-
- ## serialization
- 'logbase': 'logs',
- 'prefix': 'diffusion/',
- 'exp_name': watch(diffusion_args_to_watch),
-
- ## training
- 'n_steps_per_epoch': 10000,
- 'loss_type': 'l2',
- 'n_train_steps': 1e6,
- 'batch_size': 32,
- 'learning_rate': 2e-4,
- 'gradient_accumulate_every': 2,
- 'ema_decay': 0.995,
- 'save_freq': 1000,
- 'sample_freq': 1000,
- 'n_saves': 5,
- 'save_parallel': False,
- 'n_reference': 8,
- 'n_samples': 2,
- 'bucket': None,
- 'device': 'cuda',
- },
-}
-
-#------------------------ overrides ------------------------#
-
-## put environment-specific overrides here
-
-halfcheetah_medium_expert_v2 = {
- 'diffusion': {
- 'horizon': 16,
- },
-}
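
Note: the deleted config named experiment folders via the watched args at the top of the file. A minimal sketch of that naming scheme, assuming `diffuser.utils.watch` simply concatenates each label with the corresponding config value (the real helper may format things differently):

```python
# Hedged sketch: approximates how watched args become experiment names.
# The real diffuser.utils.watch may differ in details.
def make_exp_name(args_to_watch, config):
    parts = []
    for key, label in args_to_watch:
        if key not in config:
            continue  # e.g. 'prefix' is handled separately in the repo
        parts.append(f'{label}{config[key]}')
    return '_'.join(parts)

watched = [('prefix', ''), ('horizon', 'H'), ('n_diffusion_steps', 'T')]
config = {'horizon': 32, 'n_diffusion_steps': 100}
print(make_exp_name(watched, config))  # -> 'H32_T100'
```
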
diff --git a/config/maze2d.py b/config/maze2d.py
index a06ac7f..0a8d22a 100644
--- a/config/maze2d.py
+++ b/config/maze2d.py
@@ -34,11 +34,11 @@ base = {
'model': 'models.TemporalUnet',
'diffusion': 'models.GaussianDiffusion',
'horizon': 256,
- 'n_diffusion_steps': 256,
+ 'n_diffusion_steps': 512,
'action_weight': 1,
'loss_weights': None,
'loss_discount': 1,
- 'predict_epsilon': False,
+ 'predict_epsilon': True,
'dim_mults': (1, 4, 8),
'renderer': 'utils.Maze2dRenderer',
@@ -57,14 +57,14 @@ base = {
'exp_name': watch(diffusion_args_to_watch),
## training
- 'n_steps_per_epoch': 10000,
- 'loss_type': 'l2',
- 'n_train_steps': 2e6,
- 'batch_size': 32,
- 'learning_rate': 2e-4,
- 'gradient_accumulate_every': 2,
+ 'n_steps_per_epoch': 60000,
+ 'loss_type': 'spline',
+ 'n_train_steps': 6e4,
+ 'batch_size': 1,
+ 'learning_rate': 5e-6,
+ 'gradient_accumulate_every': 8,
'ema_decay': 0.995,
- 'save_freq': 1000,
+ 'save_freq': 2000,
'sample_freq': 1000,
'n_saves': 50,
'save_parallel': False,
@@ -89,7 +89,6 @@ base = {
'prefix': 'plans/release',
'exp_name': watch(plan_args_to_watch),
'suffix': '0',
-
'conditional': False,
## loading
@@ -122,10 +121,10 @@ maze2d_umaze_v1 = {
maze2d_large_v1 = {
'diffusion': {
'horizon': 384,
- 'n_diffusion_steps': 256,
+ 'n_diffusion_steps': 16,
},
'plan': {
'horizon': 384,
- 'n_diffusion_steps': 256,
+ 'n_diffusion_steps': 16,
},
}
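
Note: the env-specific dicts at the bottom of maze2d.py override the shared `base` dict section by section. A hedged merge sketch with minimal stand-ins for those dicts (the repo's actual config machinery may differ):

```python
# Hedged sketch of merging an env-specific override onto `base`.
# Minimal stand-ins for the dicts defined in config/maze2d.py:
base = {'diffusion': {'n_diffusion_steps': 512, 'predict_epsilon': True}}
maze2d_large_v1 = {'diffusion': {'n_diffusion_steps': 16}}

def merge_config(base_cfg, override):
    merged = {k: dict(v) for k, v in base_cfg.items()}
    for section, params in override.items():
        merged.setdefault(section, {}).update(params)
    return merged

config = merge_config(base, maze2d_large_v1)
assert config['diffusion']['n_diffusion_steps'] == 16  # overridden
assert config['diffusion']['predict_epsilon'] is True  # inherited
```
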
diff --git a/diffuser/datasets/buffer.py b/diffuser/datasets/buffer.py
index 1ad2106..5991f01 100644
--- a/diffuser/datasets/buffer.py
+++ b/diffuser/datasets/buffer.py
@@ -9,7 +9,7 @@ class ReplayBuffer:
def __init__(self, max_n_episodes, max_path_length, termination_penalty):
self._dict = {
- 'path_lengths': np.zeros(max_n_episodes, dtype=np.int),
+ 'path_lengths': np.zeros(max_n_episodes, dtype=np.int_),
}
self._count = 0
self.max_n_episodes = max_n_episodes
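
Note: this one-line fix tracks a NumPy API change. `np.int` was only an alias for the built-in `int`; it was deprecated in NumPy 1.20 and removed in 1.24, so `np.int_` is the drop-in replacement:

```python
import numpy as np

# np.int raises AttributeError on NumPy >= 1.24; np.int_ is the
# supported dtype alias the patched ReplayBuffer now uses.
path_lengths = np.zeros(8, dtype=np.int_)
assert path_lengths.dtype == np.dtype(np.int_)
```
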
diff --git a/diffuser/datasets/sequence.py b/diffuser/datasets/sequence.py
index 356c540..73c1b04 100644
--- a/diffuser/datasets/sequence.py
+++ b/diffuser/datasets/sequence.py
@@ -83,6 +83,7 @@ class SequenceDataset(torch.utils.data.Dataset):
actions = self.fields.normed_actions[path_ind, start:end]
conditions = self.get_conditions(observations)
+
trajectories = np.concatenate([actions, observations], axis=-1)
batch = Batch(trajectories, conditions)
return batch
diff --git a/diffuser/models/diffusion.py b/diffuser/models/diffusion.py
index fae4cfd..461680a 100644
--- a/diffuser/models/diffusion.py
+++ b/diffuser/models/diffusion.py
@@ -2,6 +2,7 @@ import numpy as np
import torch
from torch import nn
import pdb
+import matplotlib.pyplot as plt
import diffuser.utils as utils
from .helpers import (
@@ -9,6 +10,7 @@ from .helpers import (
extract,
apply_conditioning,
Losses,
+ catmull_rom_spline_with_rotation,
)
class GaussianDiffusion(nn.Module):
@@ -26,6 +28,7 @@ class GaussianDiffusion(nn.Module):
betas = cosine_beta_schedule(n_timesteps)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
+ print(f"Alphas Cumprod: {alphas_cumprod}")
alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])
self.n_timesteps = int(n_timesteps)
@@ -73,7 +76,7 @@ class GaussianDiffusion(nn.Module):
'''
self.action_weight = action_weight
- dim_weights = torch.ones(self.transition_dim, dtype=torch.float32)
+ dim_weights = torch.ones(self.transition_dim, dtype=torch.float64)
## set loss coefficients for dimensions of observation
if weights_dict is None: weights_dict = {}
@@ -97,18 +100,16 @@ class GaussianDiffusion(nn.Module):
otherwise, model predicts x0 directly
'''
if self.predict_epsilon:
- return (
- extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
- extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
- )
+ return noise
else:
return noise
def q_posterior(self, x_start, x_t, t):
posterior_mean = (
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
- extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
+ extract(self.posterior_mean_coef2, t, x_t.shape) * x_t[:, :, self.action_dim:]
)
+
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
@@ -129,7 +130,7 @@ class GaussianDiffusion(nn.Module):
def p_sample(self, x, cond, t):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(x=x, cond=cond, t=t)
- noise = torch.randn_like(x)
+ noise = torch.randn_like(x[:, :, self.action_dim:])
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@@ -139,22 +140,59 @@ class GaussianDiffusion(nn.Module):
device = self.betas.device
batch_size = shape[0]
- x = torch.randn(shape, device=device)
- x = apply_conditioning(x, cond, self.action_dim)
+ # x = torch.randn(shape, device=device, dtype=torch.float64)
+ # Extract known indices and values
+ known_indices = np.array(list(cond.keys()), dtype=int)
+
+ # candidate_no x batch_size x dim
+ known_values = np.stack([c.cpu().numpy() for c in cond.values()], axis=0)
+ known_values = np.moveaxis(known_values, 0, 1)
+
+ # Sort the timepoints
+ sorted_indices = np.argsort(known_indices)
+ known_indices = known_indices[sorted_indices]
+ known_values = known_values[:, sorted_indices]
+
+ # Build the structured spline guess
+ catmull_spline_trajectory = np.array([
+ catmull_rom_spline_with_rotation(known_values[b, :, :-1], known_indices, shape[1])
+ for b in range(batch_size)
+ ])
+ catmull_spline_trajectory = torch.tensor(
+ catmull_spline_trajectory,
+ dtype=torch.float64,
+ device=device
+ )
+
+
+ if self.predict_epsilon:
+ x = torch.randn((shape[0], shape[1], self.observation_dim), device=device, dtype=torch.float64)
+ cond_residual = {k: torch.zeros_like(v)[:, :-1] for k, v in cond.items()}
+ is_cond = torch.zeros((shape[0], shape[1], 1), device=device, dtype=torch.float64)
+ is_cond[:, known_indices, :] = 1.0
if return_diffusion: diffusion = [x]
- progress = utils.Progress(self.n_timesteps) if verbose else utils.Silent()
+ # progress = utils.Progress(self.n_timesteps) if verbose else utils.Silent()
for i in reversed(range(0, self.n_timesteps)):
+ if self.predict_epsilon:
+ x = torch.cat([catmull_spline_trajectory, is_cond, x], dim=-1)
+
timesteps = torch.full((batch_size,), i, device=device, dtype=torch.long)
- x = self.p_sample(x, cond, timesteps)
- x = apply_conditioning(x, cond, self.action_dim)
+ x = self.p_sample(x, cond_residual, timesteps)
+
+ x = apply_conditioning(x, cond_residual, 0)
- progress.update({'t': i})
+ if return_diffusion: diffusion.append(x)
- if return_diffusion: diffusion.append(x)
+ x = catmull_spline_trajectory + x
- progress.close()
+
+
+ # Normalize the quaternions
+ # x[:, :, 3:7] = x[:, :, 3:7] / torch.norm(x[:, :, 3:7], dim=-1, keepdim=True)
+
+ # progress.close()
if return_diffusion:
return x, torch.stack(diffusion, dim=1)
@@ -167,7 +205,7 @@ class GaussianDiffusion(nn.Module):
conditions : [ (time, state), ... ]
'''
device = self.betas.device
- batch_size = len(cond[0])
+ batch_size = len(next(iter(cond.values())))
horizon = horizon or self.horizon
shape = (batch_size, horizon, self.transition_dim)
@@ -175,38 +213,106 @@ class GaussianDiffusion(nn.Module):
#------------------------------------------ training ------------------------------------------#
- def q_sample(self, x_start, t, noise=None):
+ def q_sample(self, x_start, t, spline=None, noise=None):
+ x_start_noise = x_start[:, :, :-1]
+ x_start_is_cond = x_start[:, :, [-1]]
+
+ if spline is None:
+ spline = torch.randn_like(x_start_noise)
if noise is None:
- noise = torch.randn_like(x_start)
+ noise = torch.randn_like(x_start_noise)
- sample = (
- extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
- extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
- )
+ alpha = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
+ oneminusalpha = extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
+
+ # Weighted combination of x_0 and the spline
+ out = alpha * x_start_noise + oneminusalpha * noise
+
+ # Concatenate the binary feature and the spline as the conditioning
+ out = torch.cat([spline, x_start_is_cond, out], dim=-1)
- return sample
+ return out
def p_losses(self, x_start, cond, t):
- noise = torch.randn_like(x_start)
+ batch_size, horizon, _ = x_start.shape
+ # Extract known indices and values
+ known_indices = np.array(list(cond.keys()), dtype=int)
+
+ # candidate_no x batch_size x dim
+ known_values = np.stack([c.cpu().numpy() for c in cond.values()], axis=0)
+ known_values = np.moveaxis(known_values, 0, 1)
+
+ # Sort the timepoints
+ sorted_indices = np.argsort(known_indices)
+ known_indices = known_indices[sorted_indices]
+ known_values = known_values[:, sorted_indices]
+
+ # Build the structured spline guess
+ catmull_spline_trajectory = np.array([
+ catmull_rom_spline_with_rotation(known_values[b, :, :-1], known_indices, horizon)
+ for b in range(batch_size)
+ ])
+ catmull_spline_trajectory = torch.tensor(
+ catmull_spline_trajectory,
+ dtype=torch.float64,
+ device=x_start.device
+ )
- x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
- x_noisy = apply_conditioning(x_noisy, cond, self.action_dim)
+ # Plot the quaternions
+ # plt.plot(x_start[0, :, 3].cpu().numpy())
+ # plt.plot(catmull_spline_trajectory[0, :, 3].cpu().numpy())
+ # plt.legend(["x_start", "catmull_spline"])
+ # plt.show()
+ # raise Exception
- x_recon = self.model(x_noisy, cond, t)
- x_recon = apply_conditioning(x_recon, cond, self.action_dim)
- assert noise.shape == x_recon.shape
+ if not self.predict_epsilon:
+ # Forward diffuse with the structured trajectory
+ x_noisy = self.q_sample(
+ x_start,
+ t,
+ spline=catmull_spline_trajectory,
+ )
+ x_noisy = apply_conditioning(x_noisy, cond, self.action_dim)
- if self.predict_epsilon:
- loss, info = self.loss_fn(x_recon, noise)
+ # Reverse pass guess
+ x_recon = self.model(x_noisy, cond, t)
+ x_recon = apply_conditioning(x_recon, cond, self.action_dim)
+
+ # Then x_recon is the predicted x_0, compare to the true x_0
+ loss, info = self.loss_fn(x_recon, x_start, cond)
else:
- loss, info = self.loss_fn(x_recon, x_start)
+ residual = x_start.clone()
+
+ residual[:, :, :-1] -= catmull_spline_trajectory
+
+
+ cond_residual = {k: torch.zeros_like(v)[:, :-1] for k, v in cond.items()}
+
+ x_noisy = self.q_sample(
+ residual,
+ t,
+ spline=catmull_spline_trajectory,
+ )
+ x_noisy = apply_conditioning(x_noisy, cond_residual, self.action_dim)
+
+ # Reverse pass guess
+ x_recon = self.model(x_noisy, cond, t)
+ x_recon = apply_conditioning(x_recon, cond_residual, 0)
+
+ x_recon = x_recon + catmull_spline_trajectory
+
+ loss, info = self.loss_fn(x_recon, x_start[:, :, :-1], cond)
return loss, info
def loss(self, x, cond):
batch_size = len(x)
t = torch.randint(0, self.n_timesteps, (batch_size,), device=x.device).long()
+ # t = torch.randint(1, 2, (batch_size,), device=x.device).long()
+ # x = x.double()
+ # cond = {k: v.double() for k, v in cond.items()}
+ # print(f"Time: {t.item()}")
return self.p_losses(x, cond, t)
def forward(self, cond, *args, **kwargs):
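
Note: taken together, these diffusion.py changes make the `predict_epsilon` path diffuse a residual around a Catmull-Rom spline built from the conditioning keyframes: training noises `x0 - spline`, the network input is `[spline, is_cond, noisy residual]`, and sampling adds the spline back onto the denoised residual. A minimal sketch of that residual parameterization (names and shapes are illustrative, not the repo's exact API):

```python
import torch

# Hedged sketch of the spline-residual diffusion used above. `spline`
# stands in for catmull_rom_spline_with_rotation(...) evaluated on the
# conditioning keyframes; tensors are [batch, horizon, dim].
def q_sample_residual(x0, spline, sqrt_ac, sqrt_1m_ac, t, noise=None):
    residual = x0 - spline                 # diffuse the residual, not x0
    if noise is None:
        noise = torch.randn_like(residual)
    a = sqrt_ac[t].view(-1, 1, 1)          # sqrt(alpha_bar_t)
    b = sqrt_1m_ac[t].view(-1, 1, 1)       # sqrt(1 - alpha_bar_t)
    return a * residual + b * noise

def recover_trajectory(denoised_residual, spline):
    # after the reverse chain runs on the residual, add the spline back
    return spline + denoised_residual
```
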
diff --git a/diffuser/models/helpers.py b/diffuser/models/helpers.py
index d39f35d..9f43ef8 100644
--- a/diffuser/models/helpers.py
+++ b/diffuser/models/helpers.py
@@ -1,11 +1,11 @@
import math
+import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
-import einops
from einops.layers.torch import Rearrange
-import pdb
+from pytorch3d.transforms import quaternion_to_matrix, quaternion_to_axis_angle
import diffuser.utils as utils
@@ -30,7 +30,7 @@ class SinusoidalPosEmb(nn.Module):
class Downsample1d(nn.Module):
def __init__(self, dim):
super().__init__()
- self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
+ self.conv = nn.Conv1d(dim, dim, 3, 2, 1).to(torch.float64)
def forward(self, x):
return self.conv(x)
@@ -38,7 +38,7 @@ class Downsample1d(nn.Module):
class Upsample1d(nn.Module):
def __init__(self, dim):
super().__init__()
- self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
+ self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1).to(torch.float64)
def forward(self, x):
return self.conv(x)
@@ -52,9 +52,9 @@ class Conv1dBlock(nn.Module):
super().__init__()
self.block = nn.Sequential(
- nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
+ nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2).to(torch.float64),
Rearrange('batch channels horizon -> batch channels 1 horizon'),
- nn.GroupNorm(n_groups, out_channels),
+ nn.GroupNorm(n_groups, out_channels).to(torch.float64),
Rearrange('batch channels 1 horizon -> batch channels horizon'),
nn.Mish(),
)
@@ -72,7 +72,7 @@ def extract(a, t, x_shape):
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
-def cosine_beta_schedule(timesteps, s=0.008, dtype=torch.float32):
+def cosine_beta_schedule(timesteps, s=0.008, dtype=torch.float64):
"""
cosine schedule
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
@@ -157,9 +157,979 @@ class ValueL2(ValueLoss):
def _loss(self, pred, targ):
return F.mse_loss(pred, targ, reduction='none')
+class GeodesicL2Loss(nn.Module):
+ def __init__(self, *args):
+ super().__init__()
+ pass
+
+ def _loss(self, pred, targ):
+ # Compute L2 loss for the first three dimensions
+ l2_loss = F.mse_loss(pred[..., :3], targ[..., :3], reduction='mean')
+
+ # Normalize to unit quaternions for the last four dimensions
+ pred_quat = pred[..., 3:] / pred[..., 3:].norm(dim=-1, keepdim=True)
+ targ_quat = targ[..., 3:] / targ[..., 3:].norm(dim=-1, keepdim=True)
+
+ assert not torch.isnan(pred_quat).any(), "Pred Quat has NaNs"
+ assert not torch.isnan(targ_quat).any(), "Targ Quat has NaNs"
+
+ # Compute dot product for the quaternions
+ dot_product = torch.sum(pred_quat * targ_quat, dim=-1)
+ dot_product = torch.clamp(torch.abs(dot_product), -1.0, 1.0)
+
+ # Compute geodesic loss for the quaternions
+ geodesic_loss = 2 * torch.acos(dot_product).mean()
+
+ assert not torch.isnan(geodesic_loss).any(), "Geodesic Loss has NaNs"
+ assert not torch.isnan(l2_loss).any(), "L2 Loss has NaNs"
+
+ return l2_loss + geodesic_loss, l2_loss, geodesic_loss
+
+ def forward(self, pred, targ):
+ loss, l2, geodesic = self._loss(pred, targ)
+
+ info = {
+ 'l2': l2.item(),
+ 'geodesic': geodesic.item(),
+ }
+
+ return loss, info
+
+class RotationTranslationLoss(nn.Module):
+ def __init__(self, *args):
+ super().__init__()
+ pass
+
+ def _loss(self, pred, targ, cond=None):
+
+ # Make sure the dtype is float64
+ pred = pred.to(torch.float64)
+ targ = targ.to(torch.float64)
+
+ eps = 1e-8
+
+ pred_trans = pred[..., :3]
+ pred_quat = pred[..., 3:7]
+ targ_trans = targ[..., :3]
+ targ_quat = targ[..., 3:7]
+
+ l2_loss = F.mse_loss(pred_trans, targ_trans, reduction='mean')
+
+ # Calculate the geodesic loss
+ pred_n = pred_quat.norm(dim=-1, keepdim=True).clamp(min=eps)
+ targ_n = targ_quat.norm(dim=-1, keepdim=True).clamp(min=eps)
+
+ pred_quat_norm = pred_quat / pred_n
+ targ_quat_norm = targ_quat / targ_n
+
+
+ dot_product = torch.sum(pred_quat_norm * targ_quat_norm, dim=-1).clamp(min=-1.0 + eps, max=1.0 - eps)
+ quaternion_dist = 1 - (dot_product ** 2).mean()
+
+ # Calculate the rotation error
+ pred_rot = quaternion_to_matrix(pred_quat_norm).reshape(-1, 3, 3)
+ targ_rot = quaternion_to_matrix(targ_quat_norm).reshape(-1, 3, 3)
+
+ r2r1 = pred_rot @ targ_rot.permute(0, 2, 1)
+ trace = torch.diagonal(r2r1, dim1=-2, dim2=-1).sum(-1)
+ trace = torch.clamp((trace - 1) / 2, -1.0 + eps, 1.0 - eps)
+ geodesic_loss = torch.acos(trace).mean()
+
+ # Add a smoothness and acceleration term to the positions and quaternions
+ alpha = 1.0
+ smoothness_loss = F.mse_loss(pred[:, 1:, :7].reshape(-1, 7), pred[:, :-1, :7].reshape(-1, 7), reduction='mean')
+ acceleration_loss = F.mse_loss(pred[:, 2:, :7].reshape(-1, 7), 2 * pred[:, 1:-1, :7].reshape(-1, 7) - pred[:, :-2, :7].reshape(-1, 7), reduction='mean')
+
+ l2_multiplier = 10.0
+
+ loss = l2_multiplier * l2_loss + quaternion_dist + geodesic_loss + alpha * (smoothness_loss + acceleration_loss)
+
+ dtw = DynamicTimeWarpingLoss()
+ dtw_loss, _ = dtw.forward(pred_trans.reshape(-1, 3), targ_trans.reshape(-1, 3))
+
+ hausdorff = HausdorffDistanceLoss()
+ hausdorff_loss, _ = hausdorff.forward(pred_trans.reshape(-1, 3), targ_trans.reshape(-1, 3))
+
+ frec = FrechetDistanceLoss()
+ frechet_loss, _ = frec.forward(pred_trans.reshape(-1, 3), targ_trans.reshape(-1, 3))
+
+ chamfer = ChamferDistanceLoss()
+ chamfer_loss, _ = chamfer.forward(pred_trans.reshape(-1, 3), targ_trans.reshape(-1, 3))
+
+ return loss, l2_loss, geodesic_loss, quaternion_dist, dtw_loss, hausdorff_loss, frechet_loss, chamfer_loss
+
+
+ def forward(self, pred, targ, cond=None):
+ loss, err_t, err_geo, err_r, err_dtw, err_hausdorff, err_frechet, err_chamfer = self._loss(pred, targ, cond)
+
+ info = {
+ 'rot. error': err_r.item(),
+ 'geodesic error': err_geo.item(),
+ 'trans. error': err_t.item(),
+ 'dtw': err_dtw.item(),
+ 'hausdorff': err_hausdorff.item(),
+ 'frechet': err_frechet.item(),
+ 'chamfer': err_chamfer.item(),
+ }
+
+ return loss, info
+
+class SplineLoss(nn.Module):
+ def __init__(self, *args):
+ super().__init__()
+ self.scales = json.load(open('scene_scale.json'))
+
+ def compute_spline_coeffs(self, trans):
+ p0 = trans[:, :-3, :]
+ p1 = trans[:, 1:-2, :]
+ p2 = trans[:, 2:-1, :]
+ p3 = trans[:, 3:, :]
+
+ # Tangent approximations
+ m1 = 0.5 * (-p0 + p2)
+ m2 = 0.5 * (-p1 + p3)
+
+ # Cubic spline coefficients for each dimension
+ a = (2 * p1 - 2 * p2 + m1 + m2)
+ b = (-3 * p1 + 3 * p2 - 2 * m1 - m2)
+ c = (m1)
+ d = (p1)
+
+ return torch.stack([a, b, c, d], dim=-1)
+
+ def q_normalize(self, q):
+ return q / q.norm(p=2, dim=-1, keepdim=True).clamp(min=1e-12)
+
+ def q_conjugate(self, q):
+ w, x, y, z = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
+ return torch.stack([w, -x, -y, -z], dim=-1)
+
+ def q_multiply(self, q1, q2):
+ """
+ q1*q2.
+ """
+ w1, x1, y1, z1 = q1.unbind(-1)
+ w2, x2, y2, z2 = q2.unbind(-1)
+ w = w1*w2 - x1*x2 - y1*y2 - z1*z2
+ x = w1*x2 + x1*w2 + y1*z2 - z1*y2
+ y = w1*y2 - x1*z2 + y1*w2 + z1*x2
+ z = w1*z2 + x1*y2 - y1*x2 + z1*w2
+ return torch.stack([w, x, y, z], dim=-1)
+
+ def q_inverse(self, q):
+ return self.q_conjugate(self.q_normalize(q))
+
+ def q_log(self, q):
+ """
+ Quaternion logarithm for a unit quaternion
+ Only returns the imaginary part
+ """
+ q = self.q_normalize(q)
+ w = q[..., 0]
+ xyz = q[..., 1:] # shape [..., 3]
+ mag_v = xyz.norm(p=2, dim=-1)
+ eps = 1e-12
+ angle = torch.acos(w.clamp(-1.0 + eps, 1.0 - eps))
+
+ # Guard against division by zero in sin(angle)
+ small_mask = (mag_v < 1e-12) | (angle < 1e-12)
+ # Where small_mask is True => near identity => log(q) ~ 0
+ log_val = torch.zeros_like(xyz)
+
+ # Normal case
+ scale = angle / mag_v.clamp(min=1e-12)
+ normal_case = scale.unsqueeze(-1) * xyz
+
+ log_val = torch.where(
+ small_mask.unsqueeze(-1),
+ torch.zeros_like(xyz),
+ normal_case
+ )
+ return log_val
+
+ def q_exp(self, v):
+ """
+ Quaternion exponential
+ """
+ norm_v = v.norm(p=2, dim=-1)
+ small_mask = norm_v < 1e-12
+
+ w = torch.cos(norm_v)
+ sin_v = torch.sin(norm_v)
+ scale = torch.where(
+ small_mask,
+ torch.zeros_like(norm_v), # if zero, sin(0)/0 => 0
+ sin_v / norm_v.clamp(min=1e-12)
+ )
+ xyz = scale.unsqueeze(-1) * v
+
+ # For small angles, we approximate cos(norm_v) ~ 1, sin(norm_v)/norm_v ~ 1
+ w = torch.where(
+ small_mask,
+ torch.ones_like(w),
+ w
+ )
+ return torch.cat([w.unsqueeze(-1), xyz], dim=-1)
+
+ def q_slerp(self, q1, q2, t):
+ """
+ Spherical linear interpolation from q1 to q2 at t in [0,1].
+ Both q1, q2 assumed normalized.
+ q1, q2, t can be 1D or broadcastable shapes, but typically 1D.
+ """
+ q1 = self.q_normalize(q1)
+ q2 = self.q_normalize(q2)
+ dot = (q1 * q2).sum(dim=-1, keepdim=True) # the dot product
+
+ eps = 1e-12
+ dot = dot.clamp(-1.0 + eps, 1.0 - eps)
+
+ flip_mask = dot < 0.0
+ if flip_mask.any():
+ q2 = torch.where(flip_mask, -q2, q2)
+ dot = torch.where(flip_mask, -dot, dot)
+
+ # If they're very close, do a simple linear interpolation
+ close_mask = dot.squeeze(-1) > 0.9995
+ # The 0.9995 threshold avoids numerical issues when q1 and q2 are nearly parallel
+
+ # Branch 1: Very close
+ # linear LERP
+ lerp_val = (1.0 - t) * q1 + t * q2
+ lerp_val = self.q_normalize(lerp_val)
+
+ # Branch 2: Standard SLERP
+ theta_0 = torch.acos(dot)
+ sin_theta_0 = torch.sin(theta_0)
+ theta = theta_0 * t
+ s1 = torch.sin(theta_0 - theta) / sin_theta_0.clamp(min=1e-12)
+ s2 = torch.sin(theta) / sin_theta_0.clamp(min=1e-12)
+ slerp_val = s1 * q1 + s2 * q2
+ slerp_val = self.q_normalize(slerp_val)
+
+ # Combine
+ return torch.where(
+ close_mask.unsqueeze(-1),
+ lerp_val,
+ slerp_val
+ )
+
+ def compute_uniform_tangent(self, q_im1, q_i, q_ip1):
+ """
+ Computes a 'Catmull–Rom-like' tangent T_i for quaternion q_i,
+ given neighbors q_im1, q_i, q_ip1.
+
+ T_i = q_i * exp( -0.25 * [ log(q_i^-1 q_ip1) + log(q_i^-1 q_im1) ] )
+ """
+ q_im1 = self.q_normalize(q_im1)
+ q_i = self.q_normalize(q_i)
+ q_ip1 = self.q_normalize(q_ip1)
+
+ inv_qi = self.q_inverse(q_i)
+ r1 = self.q_multiply(inv_qi, q_ip1)
+ r2 = self.q_multiply(inv_qi, q_im1)
+
+ lr1 = self.q_log(r1)
+ lr2 = self.q_log(r2)
+
+ m = -0.25 * (lr1 + lr2)
+ exp_m = self.q_exp(m)
+ return self.q_multiply(q_i, exp_m)
+
+ def compute_all_uniform_tangents(self, quats):
+ """
+ Vectorized version that computes tangents T_i for all keyframe quaternions at once.
+ quats shape: [N,4], N >= 2
+ Returns shape [N,4].
+ """
+ q_im1 = torch.cat([quats[[0]], quats[:-1]], dim=0) # q_im1[0] = q0
+ q_ip1 = torch.cat([quats[1:], quats[[-1]]], dim=0) # q_ip1[N-1]= q_{N-1}
+
+ return self.compute_uniform_tangent(q_im1, quats, q_ip1)
+
+ def squad(self, q0, a, b, q1, t):
+ """
+ Shoemake's "squad" interpolation for quaternion splines:
+ squad(q0, a, b, q1; t) = slerp( slerp(q0, q1; t),
+ slerp(a, b; t),
+ 2t(1-t) )
+ where a, b are tangential control quaternions for q0, q1.
+ """
+ s1 = self.q_slerp(q0, q1, t)
+ s2 = self.q_slerp(a, b, t)
+ alpha = 2.0*t*(1.0 - t)
+ return self.q_slerp(s1, s2, alpha)
+
+ def uniform_cr_spline(self, quats, num_samples_per_segment=10):
+ """
+ Given a list of keyframe quaternions quats (each a torch 1D tensor [4]),
+ compute a "Uniform Catmull–Rom–like" quaternion spline through them.
+
+ Returns:
+ A list (Python list) of interpolated quaternions (torch tensors),
+ including all segment endpoints.
+
+ Each interior qi gets a tangent T_i using neighbors q_{i-1}, q_i, q_{i+1}.
+ For boundary tangents, we replicate the end quaternions.
+ """
+ n = quats.shape[0]
+ if n < 2:
+ return quats.unsqueeze(0) # not enough quats to interpolate
+
+ # Precompute tangents
+ tangents = self.compute_all_uniform_tangents(quats)
+
+ # Interpolate each segment [qi, q_{i+1}]
+ q0 = quats[:-1].unsqueeze(1)
+ q1 = quats[1:].unsqueeze(1)
+ a = tangents[:-1].unsqueeze(1)
+ b = tangents[1:].unsqueeze(1)
+
+ t_vals = torch.linspace(0.0, 1.0, num_samples_per_segment, device=quats.device, dtype=quats.dtype)
+ t_vals = t_vals.view(1, -1, 1)
+
+ out = self.squad(q0, a, b, q1, t_vals)
+ return out
+
+
+ def forward(self, pred, targ, cond=None, scene_id=None, norm_params=None):
+ loss, err_t, err_smooth, err_geo, err_r, err_dtw, err_hausdorff, err_frechet, err_chamfer = self._loss(pred, targ, cond, scene_id, norm_params)
+
+ info = {
+ 'trans. error': err_t.item(),
+ 'smoothness error': err_smooth.item(),
+ # 'dtw': err_dtw.item(),
+ # 'hausdorff': err_hausdorff.item(),
+ # 'frechet': err_frechet.item(),
+ # 'chamfer': err_chamfer.item(),
+ 'quat. dist.': err_r.item(),
+ 'geodesic dist.': err_geo.item(),
+ }
+
+ return loss, info
+
+ def _loss(self, pred, targ, cond=None, scene_id=None, norm_params=None):
+ def poly_eval(coeffs, x):
+ """
+ Evaluates a polynomial (with highest-degree term first) at points x.
+ coeffs: 2D tensor of shape [num_polynomials, degree + 1], highest-degree term first.
+ x: 1D tensor of points at which to evaluate the polynomial.
+ Returns:
+ 2D tensor of shape [num_polynomials, len(x)], containing p(x).
+ """
+ x_powers = torch.stack([x**i for i in range(coeffs.shape[-1] - 1, -1, -1)], dim=-1)
+ x_powers = x_powers.to(torch.float64).to(coeffs.device)
+ y = torch.matmul(coeffs, x_powers.T)
+ return y
+
+ # Make sure the dtype is float64
+ pred = pred.to(torch.float64)
+ targ = targ.to(torch.float64)
+
+ # Rescale the translations
+ if scene_id is not None and norm_params is not None:
+ scene_id = scene_id.item()
+ scene_scale = self.scales[str(scene_id)]
+ scene_scale = norm_params['scale'][0] * scene_scale
+ pred[..., :3] = pred[..., :3] * scene_scale
+ targ[..., :3] = targ[..., :3] * scene_scale
+ # print(pred[..., :3].max(), targ[..., :3].max())
+
+ # We only consider interpolated points for loss calculation
+ candidate_idxs = sorted(cond.keys())
+ pred = pred[:, candidate_idxs[0] : candidate_idxs[-1] + 1, :]
+ targ = targ[:, candidate_idxs[0] : candidate_idxs[-1] + 1, :]
+
+ pred_trans = pred[..., :3]
+ pred_quat = pred[..., 3:7]
+ targ_trans = targ[..., :3]
+ targ_quat = targ[..., 3:7]
+
+ pred_coeffs = self.compute_spline_coeffs(pred_trans)
+ targ_coeffs = self.compute_spline_coeffs(targ_trans)
+
+ n_points = 2000
+
+ # Distribute sample points among intervals
+ dists = torch.norm(targ_trans[:, 1:, :] - targ_trans[:, :-1, :], dim=-1).reshape(-1)
+ dists_c = torch.zeros(len(candidate_idxs) - 1, device=pred.device)
+ for i in range(len(candidate_idxs) - 1):
+ dists_c[i] = dists[candidate_idxs[i]:candidate_idxs[i+1]].sum()
+
+ weights_c = dists_c / dists_c.sum()
+ scaled_c = weights_c * n_points
+ points_c = torch.floor(scaled_c).int()
+
+ while points_c.sum() < n_points:
+ idx = torch.argmax(scaled_c - points_c)
+ points_c[idx] += 1
+
+ # Calculate the spline loss
+ sample_points = 50
+ x = torch.linspace(0, 1, sample_points, device=pred.device)
+ pred_spline = poly_eval(pred_coeffs, x).permute(0, 1, 3, 2).reshape(-1, sample_points, 3)
+ targ_spline = poly_eval(targ_coeffs, x).permute(0, 1, 3, 2).reshape(-1, sample_points, 3)
+
+ indexes = []
+ start_idx = candidate_idxs[0]
+ for c, (idx_i0, idx_i1) in enumerate(zip(candidate_idxs[:-1], candidate_idxs[1:])):
+ p = points_c[c]
+ total_dist = dists_c[c]
+ dist_arr = dists[idx_i0 - start_idx : idx_i1 - start_idx]
+
+ step_distances = (dist_arr / sample_points).repeat_interleave(sample_points)
+ cumul_distances = step_distances.cumsum(dim=0)
+
+ dist_per_pick = total_dist / p
+ pick_targets = torch.arange(1, p + 1, device=dists.device) * dist_per_pick
+
+ pick_idxs = torch.searchsorted(cumul_distances, pick_targets, right=True)
+ pick_idxs = torch.clamp(pick_idxs, max=len(cumul_distances) - 1)
+
+
+ indexes_1d = torch.zeros_like(step_distances)
+ indexes_1d[pick_idxs] = 1
+
+ indexes_2d = indexes_1d.view(len(dist_arr), sample_points)
+
+ indexes.append(indexes_2d)
+
+ indexes = torch.cat(indexes)[1: -1] # The first and last candidates don't have spline representations
+
+ indexes_trans = torch.stack([indexes for _ in range(3)], dim=-1)
+ indexes_quat = torch.stack([indexes for _ in range(4)], dim=-1)
+
+ indexes_trans = indexes_trans.to(torch.bool)
+ indexes_quat = indexes_quat.to(torch.bool)
+
+ pred_trans_selected_values = pred_spline[indexes_trans]
+ targ_trans_selected_values = targ_spline[indexes_trans]
+
+ pred_trans_selected_values = pred_trans_selected_values.reshape(-1, 3)
+ targ_trans_selected_values = targ_trans_selected_values.reshape(-1, 3)
+
+ # Calculate the loss for quaternions
+ pred_quat = pred_quat / pred_quat.norm(dim=-1, keepdim=True).clamp(min=1e-8)
+ targ_quat = targ_quat / targ_quat.norm(dim=-1, keepdim=True).clamp(min=1e-8)
+
+ targ_quat_spline = self.uniform_cr_spline(targ_quat.reshape(-1, 4), num_samples_per_segment=sample_points)
+ pred_quat_spline = self.uniform_cr_spline(pred_quat.reshape(-1, 4), num_samples_per_segment=sample_points)
+
+
+ targ_quat_spline = targ_quat_spline[1:-1]
+ pred_quat_spline = pred_quat_spline[1:-1]
+
+
+ pred_quat_selected_values = pred_quat_spline[indexes_quat]
+ targ_quat_selected_values = targ_quat_spline[indexes_quat]
+
+ pred_quat_selected_values = pred_quat_selected_values.reshape(-1, 4)
+ targ_quat_selected_values = targ_quat_selected_values.reshape(-1, 4)
+
+ # Calculate the geodesic loss
+ pred_rot = quaternion_to_matrix(pred_quat_selected_values).reshape(-1, 3, 3)
+ targ_rot = quaternion_to_matrix(targ_quat_selected_values).reshape(-1, 3, 3)
+
+ eps = 1e-12
+ r2r1 = pred_rot @ targ_rot.permute(0, 2, 1)
+ trace = torch.diagonal(r2r1, dim1=-2, dim2=-1).sum(-1)
+ trace = torch.clamp((trace - 1) / 2, -1.0 + eps, 1.0 - eps)
+ geodesic_loss = torch.acos(trace).mean()
+
+ # Calculate the rotation error
+ dot_product = torch.sum(pred_quat_selected_values * targ_quat_selected_values, dim=-1).clamp(min=-1.0 + eps, max=1.0 - eps)
+ quaternion_dist = 1 - (dot_product ** 2).mean()
+
+ # Calculate the L2 loss
+ l2_loss = F.mse_loss(pred_trans_selected_values, targ_trans_selected_values, reduction='mean')
+
+ # Calculate the smoothness loss for translation and quaternion
+ smoothness_multiplier = 10 ** 2 # Empirically determined multiplier for smoothness loss
+ weight_acceleration = 0.1
+ weight_jerk = 0.05
+
+ pos_acc = pred_trans_selected_values[2:, :] - 2 * pred_trans_selected_values[1:-1, :] + pred_trans_selected_values[:-2, :]
+ pos_jerk = pred_trans_selected_values[3:, :] - 3 * pred_trans_selected_values[2:-1, :] + 3 * pred_trans_selected_values[1:-2, :] - pred_trans_selected_values[:-3, :]
+
+ pos_acceleration_loss = torch.mean(pos_acc ** 2)
+ pos_jerk_loss = torch.mean(pos_jerk ** 2)
+
+ q0 = pred_quat_selected_values[:-1, :]
+ q1 = pred_quat_selected_values[1:, :]
+ sign = torch.where((q0 * q1).sum(dim=-1) < 0, -1.0, 1.0)
+ q1 = sign.unsqueeze(-1) * q1
+
+ dq = self.q_multiply(q1, self.q_inverse(q0))
+ theta = 2 * torch.acos(torch.clamp(dq[..., 0], -1.0 + 1e-8, 1.0 - 1e-8))
+
+ rot_acc = theta[2:] - 2*theta[1:-1] + theta[:-2]
+ rot_jerk = theta[3:] - 3*theta[2:-1] + 3*theta[1:-2] - theta[:-3]
+
+ rot_acceleration_loss = torch.mean(rot_acc ** 2)
+ rot_jerk_loss = torch.mean(rot_jerk ** 2)
+
+ alpha_rot = 0.1 # <-- tune this (e.g. 0.1 … 10)
+
+
+ acceleration_loss = pos_acceleration_loss + alpha_rot * rot_acceleration_loss
+ jerk_loss = pos_jerk_loss + alpha_rot * rot_jerk_loss
+
+ smoothness_loss = (
+ weight_acceleration * acceleration_loss
+ + weight_jerk * jerk_loss
+ ) * smoothness_multiplier
+
+
+ # Calculate the spline loss
+ l2_multiplier = 10.0
+ spline_loss = l2_multiplier * (l2_loss + smoothness_loss) + geodesic_loss + quaternion_dist
+
+ dtw_loss, hausdorff_loss, frechet_loss, chamfer_loss = None, None, None, None
+
+ # Uncomment these lines if you want to use the other losses
+ '''
+ dtw = DynamicTimeWarpingLoss()
+ dtw_loss, _ = dtw.forward(pred_trans_selected_values.reshape(-1, 3), targ_trans_selected_values.reshape(-1, 3))
+
+ hausdorff = HausdorffDistanceLoss()
+ hausdorff_loss, _ = hausdorff.forward(pred_trans_selected_values.reshape(-1, 3), targ_trans_selected_values.reshape(-1, 3))
+
+ frec = FrechetDistanceLoss()
+ frechet_loss, _ = frec.forward(pred_trans_selected_values.reshape(-1, 3), targ_trans_selected_values.reshape(-1, 3))
+
+ chamfer = ChamferDistanceLoss()
+ chamfer_loss, _ = chamfer.forward(pred_trans_selected_values.reshape(-1, 3), targ_trans_selected_values.reshape(-1, 3))
+ '''
+
+ return spline_loss, l2_multiplier * l2_loss, l2_multiplier * smoothness_loss, geodesic_loss, quaternion_dist, dtw_loss, hausdorff_loss, frechet_loss, chamfer_loss
+
+
+class DynamicTimeWarpingLoss(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def _dtw_distance(self, seq1: torch.Tensor, seq2: torch.Tensor) -> torch.Tensor:
+ """
+ Computes the DTW distance between two 2D tensors (T x D),
+ where T is sequence length and D is feature dimension.
+ """
+ # seq1, seq2 shapes: (time_steps, feature_dim)
+ n, m = seq1.size(0), seq2.size(0)
+
+ # Cost matrix (pairwise distances between all elements)
+ cost = torch.zeros(n, m, device=seq1.device, dtype=seq1.dtype)
+ for i in range(n):
+ for j in range(m):
+ cost[i, j] = torch.norm(seq1[i] - seq2[j], p=2)
+
+ # Accumulated cost matrix
+ dist = torch.full((n + 1, m + 1), float('inf'),
+ device=seq1.device, dtype=seq1.dtype)
+ dist[0, 0] = 0.0
+
+ # Populate the DP table
+ for i in range(1, n + 1):
+ for j in range(1, m + 1):
+ dist[i, j] = cost[i - 1, j - 1] + torch.min(
+ torch.min(
+ dist[i - 1, j], # Insertion
+ dist[i, j - 1], # Deletion
+ ),
+ dist[i - 1, j - 1]# Match
+ )
+
+ return dist[n, m]
+
+ def _loss(self, pred: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
+ """
+ Compute the average DTW loss over a batch of sequences.
+
+ pred, targ shapes: (batch_size, T, D)
+ """
+ # Ensure shapes match in batch dimension
+ assert pred.size(0) == targ.size(0), "Batch sizes must match."
+
+ # Compute DTW distance per sample in the batch
+ distances = []
+ for b in range(pred.size(0)):
+ seq1 = pred[b]
+ seq2 = targ[b]
+ dtw_val = self._dtw_distance(seq1, seq2)
+ distances.append(dtw_val)
+
+ # Stack and take mean to get scalar loss
+ dtw_loss = torch.stack(distances).mean()
+ return dtw_loss
+
+ def forward(self, pred: torch.Tensor, targ: torch.Tensor):
+ """
+ Returns a tuple: (loss, info_dict),
+ where loss is a scalar tensor and info_dict is a dictionary
+ of extra information (e.g., loss components).
+ """
+ loss = self._loss(pred, targ)
+
+ info = {
+ 'dtw': loss.item()
+ }
+
+ return loss, info
+
+class HausdorffDistanceLoss(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def _hausdorff_distance(self, set1: torch.Tensor, set2: torch.Tensor) -> torch.Tensor:
+ """
+ Computes the Hausdorff distance between two 2D tensors (N x D),
+ where N is the number of points and D is the feature dimension.
+
+ The Hausdorff distance H(A,B) between two sets A and B is defined as:
+ H(A, B) = max( h(A, B), h(B, A) ),
+ where
+ h(A, B) = max_{a in A} min_{b in B} d(a, b).
+
+ Here, d(a, b) is the Euclidean distance between points a and b.
+ """
+ # set1, set2 shapes: (num_points, feature_dim)
+ n, m = set1.size(0), set2.size(0)
+
+ # Compute pairwise distances
+ cost = torch.zeros(n, m, device=set1.device, dtype=set1.dtype)
+ for i in range(n):
+ for j in range(m):
+ cost[i, j] = torch.norm(set1[i] - set2[j], p=2)
+
+ # Forward direction: for each point in set1, find distance to closest point in set2
+ forward_min = cost.min(dim=1)[0] # Shape (n,)
+ forward_hausdorff = forward_min.max() # max over n
+
+ # Backward direction: for each point in set2, find distance to closest point in set1
+ backward_min = cost.min(dim=0)[0] # Shape (m,)
+ backward_hausdorff = backward_min.max() # max over m
+
+ # Hausdorff distance is the max of the two
+ hausdorff_dist = torch.max(forward_hausdorff, backward_hausdorff)
+ return hausdorff_dist
+
+ def _loss(self, pred: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
+ """
+ Compute the average Hausdorff distance over a batch of point sets.
+
+ pred, targ shapes: (batch_size, N, D)
+ """
+ # Ensure shapes match in batch dimension
+ assert pred.size(0) == targ.size(0), "Batch sizes must match."
+
+ distances = []
+ for b in range(pred.size(0)):
+ set1 = pred[b]
+ set2 = targ[b]
+ h_dist = self._hausdorff_distance(set1, set2)
+ distances.append(h_dist)
+
+ # Stack and take mean to get scalar loss
+ hausdorff_loss = torch.stack(distances).mean()
+ return hausdorff_loss
+
+ def forward(self, pred: torch.Tensor, targ: torch.Tensor):
+ """
+ Returns a tuple: (loss, info_dict),
+ where loss is a scalar tensor and info_dict is a dictionary
+ of extra information (e.g., distance components).
+ """
+ loss = self._loss(pred, targ)
+
+ info = {
+ 'hausdorff': loss.item()
+ }
+
+ return loss, info
+
+class FrechetDistanceLoss(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def _frechet_distance(self, seq1: torch.Tensor, seq2: torch.Tensor) -> torch.Tensor:
+ """
+ Computes the (discrete) Fréchet distance between two 2D tensors (T x D),
+ where T is the sequence length and D is the feature dimension.
+
+ The Fréchet distance between two curves in discrete form can be computed
+ by filling in a DP table “ca” where:
+
+ ca[i, j] = max( d(seq1[i], seq2[j]),
+ min(ca[i-1, j], ca[i, j-1], ca[i-1, j-1]) )
+
+ with boundary conditions handled appropriately.
+ Here, d(seq1[i], seq2[j]) is the Euclidean distance.
+ """
+ n, m = seq1.size(0), seq2.size(0)
+
+ # Cost matrix (pairwise distances between all elements)
+ cost = torch.zeros(n, m, device=seq1.device, dtype=seq1.dtype)
+ for i in range(n):
+ for j in range(m):
+ cost[i, j] = torch.norm(seq1[i] - seq2[j], p=2)
+
+ # DP matrix for the Fréchet distance
+ ca = torch.full((n, m), float('inf'), device=seq1.device, dtype=seq1.dtype)
+ ca[0, 0] = cost[0, 0]
+
+ # Initialize first row
+ for i in range(1, n):
+ ca[i, 0] = torch.max(ca[i - 1, 0], cost[i, 0])
+
+ # Initialize first column
+ for j in range(1, m):
+ ca[0, j] = torch.max(ca[0, j - 1], cost[0, j])
+
+ # Populate the DP table
+ for i in range(1, n):
+ for j in range(1, m):
+ ca[i, j] = torch.max(
+ cost[i, j],
+ torch.min(
+ torch.min(
+ ca[i - 1, j],
+ ca[i, j - 1],
+ ),
+ ca[i - 1, j - 1]
+ )
+ )
+
+ return ca[n - 1, m - 1]
+
+ def _loss(self, pred: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
+ """
+ Compute the average Fréchet distance over a batch of sequences.
+
+ pred, targ shapes: (batch_size, T, D)
+ """
+ # Ensure shapes match in batch dimension
+ assert pred.size(0) == targ.size(0), "Batch sizes must match."
+
+ distances = []
+ for b in range(pred.size(0)):
+ seq1 = pred[b]
+ seq2 = targ[b]
+ fd_val = self._frechet_distance(seq1, seq2)
+ distances.append(fd_val)
+
+ # Stack and take mean to get scalar loss
+ frechet_loss = torch.stack(distances).mean()
+ return frechet_loss
+
+ def forward(self, pred: torch.Tensor, targ: torch.Tensor):
+ """
+ Returns a tuple: (loss, info_dict),
+ where loss is a scalar tensor and info_dict is a dictionary
+ of extra information (e.g., distance components).
+ """
+ loss = self._loss(pred, targ)
+ info = {
+ 'frechet': loss.item()
+ }
+ return loss, info
+
+class ChamferDistanceLoss(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def _chamfer_distance(self, set1: torch.Tensor, set2: torch.Tensor) -> torch.Tensor:
+ """
+ Computes the symmetrical Chamfer distance between
+ two 2D tensors (N x D), where N is the number of points
+ and D is the feature dimension.
+
+ The Chamfer distance between two point sets A and B is often defined as:
+
+ d_chamfer(A, B) = 1/|A| ∑_{a ∈ A} min_{b ∈ B} ‖a - b‖₂
+ + 1/|B| ∑_{b ∈ B} min_{a ∈ A} ‖b - a‖₂,
+
+ where ‖·‖₂ is the Euclidean distance.
+ """
+ # set1, set2 shapes: (num_points, feature_dim)
+ n, m = set1.size(0), set2.size(0)
+
+ cost = torch.zeros(n, m, device=set1.device, dtype=set1.dtype)
+ for i in range(n):
+ for j in range(m):
+ cost[i, j] = torch.norm(set1[i] - set2[j], p=2)
+
+ # For each point in set1, find distance to the closest point in set2
+ forward_min = cost.min(dim=1)[0] # shape: (n,)
+ forward_mean = forward_min.mean()
+
+ # For each point in set2, find distance to the closest point in set1
+ backward_min = cost.min(dim=0)[0] # shape: (m,)
+ backward_mean = backward_min.mean()
+
+ chamfer_dist = forward_mean + backward_mean
+ return chamfer_dist
+
+ def _loss(self, pred: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
+ """
+ Compute the average Chamfer distance over a batch of point sets.
+
+ pred, targ shapes: (batch_size, N, D)
+ """
+ # Ensure shapes match in batch dimension
+ assert pred.size(0) == targ.size(0), "Batch sizes must match."
+
+ distances = []
+ for b in range(pred.size(0)):
+ set1 = pred[b]
+ set2 = targ[b]
+ distance_val = self._chamfer_distance(set1, set2)
+ distances.append(distance_val)
+
+ # Combine into a single scalar
+ chamfer_loss = torch.stack(distances).mean()
+ return chamfer_loss
+
+ def forward(self, pred: torch.Tensor, targ: torch.Tensor):
+ """
+ Returns a tuple: (loss, info_dict),
+ where 'loss' is a scalar tensor and 'info_dict' is a dictionary
+ of extra information (e.g., distance components).
+ """
+ loss = self._loss(pred, targ)
+ info = {
+ 'chamfer': loss.item()
+ }
+ return loss, info
+
+
+def slerp(q1, q2, t):
+ """Spherical linear interpolation between two quaternions."""
+ q1 = q1 / np.linalg.norm(q1)
+ q2 = q2 / np.linalg.norm(q2)
+ dot = np.dot(q1, q2)
+
+ if dot < 0.0:
+ q2 = -q2
+ dot = -dot
+ # If dot is very close to 1, use linear interpolation
+
+ if dot > 0.9995:
+ result = q1 + t * (q2 - q1)
+ result = result / np.linalg.norm(result)
+ return result
+
+ theta_0 = np.arccos(dot)
+ theta = theta_0 * t
+
+ q3 = q2 - q1 * dot
+ q3 = q3 / np.linalg.norm(q3)
+ return q1 * np.cos(theta) + q3 * np.sin(theta)
+
+def catmull_rom_spline_with_rotation(control_points, timepoints, horizon):
+ """Compute Catmull-Rom spline for both position and quaternion rotation."""
+ spline_points = []
+ # Extrapolate the initial points
+ if timepoints[0] != 0:
+ for t in range(timepoints[0]):
+ x = control_points[0][0]
+ y = control_points[0][1]
+ z = control_points[0][2]
+ q = control_points[0][3:7]
+ spline_points.append(np.concatenate([np.array([x, y, z]), q]))
+
+ # Linearly interpolate between the 0th and 1st control points
+ for t in np.linspace(0, 1, timepoints[1] - timepoints[0] + 1):
+ x = control_points[0][0] + t * (control_points[1][0] - control_points[0][0])
+ y = control_points[0][1] + t * (control_points[1][1] - control_points[0][1])
+ z = control_points[0][2] + t * (control_points[1][2] - control_points[0][2])
+ q = slerp(control_points[0][3:7], control_points[1][3:7], t)
+ spline_points.append(np.concatenate([np.array([x, y, z]), q]))
+
+
+ # Iterate over the control points
+ for i in range(1, len(control_points) - 2):
+ P0 = control_points[i-1][:3]
+ P1 = control_points[i][:3]
+ P2 = control_points[i+1][:3]
+ P3 = control_points[i+2][:3]
+ Q0 = control_points[i-1][3:7]
+ Q1 = control_points[i][3:7]
+ Q2 = control_points[i+1][3:7]
+ Q3 = control_points[i+2][3:7]
+
+ # Interpolate position (using Catmull-Rom spline)
+ for idx, t in enumerate(np.linspace(0, 1, timepoints[i+1] - timepoints[i] + 1)):
+ if idx == 0:
+ continue
+
+ x = 0.5 * ((2 * P1[0]) + (-P0[0] + P2[0]) * t +
+ (2 * P0[0] - 5 * P1[0] + 4 * P2[0] - P3[0]) * t**2 +
+ (-P0[0] + 3 * P1[0] - 3 * P2[0] + P3[0]) * t**3)
+ y = 0.5 * ((2 * P1[1]) + (-P0[1] + P2[1]) * t +
+ (2 * P0[1] - 5 * P1[1] + 4 * P2[1] - P3[1]) * t**2 +
+ (-P0[1] + 3 * P1[1] - 3 * P2[1] + P3[1]) * t**3)
+ z = 0.5 * ((2 * P1[2]) + (-P0[2] + P2[2]) * t +
+ (2 * P0[2] - 5 * P1[2] + 4 * P2[2] - P3[2]) * t**2 +
+ (-P0[2] + 3 * P1[2] - 3 * P2[2] + P3[2]) * t**3)
+ q = slerp(Q1, Q2, t)
+ spline_points.append(np.concatenate([np.array([x, y, z]), q]))
+
+ # Linearly interpolate between the second-to-last and last control points
+ for idx, t in enumerate(np.linspace(0, 1, timepoints[-1] - timepoints[-2] + 1)):
+ if idx == 0:
+ continue
+ x = control_points[-2][0] + t * (control_points[-1][0] - control_points[-2][0])
+ y = control_points[-2][1] + t * (control_points[-1][1] - control_points[-2][1])
+ z = control_points[-2][2] + t * (control_points[-1][2] - control_points[-2][2])
+ q = slerp(control_points[-2][3:7], control_points[-1][3:7], t)
+ spline_points.append(np.concatenate([np.array([x, y, z]), q]))
+
+ # Extrapolate the rest of the points
+ if timepoints[-1] != horizon:
+ for t in range(timepoints[-1] + 1, horizon):
+ x = control_points[-1][0]
+ y = control_points[-1][1]
+ z = control_points[-1][2]
+ q = control_points[-1][3:7]
+ spline_points.append(np.concatenate([np.array([x, y, z]), q]))
+
+ stacked_spline_points = np.stack(spline_points, axis=0)
+
+ if control_points.shape[1] != 7:
+ stacked_spline_points = np.concatenate([stacked_spline_points, np.zeros((stacked_spline_points.shape[0], 1))], axis=1)
+
+
+ return stacked_spline_points
+
+def catmull_rom_loss(trajectories, conditions, loss_fc):
+ '''
+ Loss for Catmull-Rom interpolation.
+ '''
+ batch_size, horizon, transition = trajectories.shape
+
+ # Extract known indices and values
+ known_indices = np.array(list(conditions.keys()), dtype=int)
+
+ # candidate_no x batch_size x dim
+ known_values = np.stack([c.cpu().numpy() for c in conditions.values()], axis=0)
+ known_values = np.moveaxis(known_values, 0, 1)
+
+ # Sort the timepoints
+ sorted_indices = np.argsort(known_indices)
+ known_indices = known_indices[sorted_indices]
+ known_values = known_values[:, sorted_indices]
+ spline_points = np.array([catmull_rom_spline_with_rotation(known_values[b], known_indices, horizon) for b in range(batch_size)])
+
+ # Convert to tensor and move to the same device as trajectories
+ spline_points = torch.tensor(spline_points, dtype=torch.float64, device=trajectories.device)
+ assert spline_points.shape == trajectories.shape, f"Shape mismatch: {spline_points.shape} != {trajectories.shape}"
+ return loss_fc(spline_points, trajectories)
+
Losses = {
'l1': WeightedL1,
'l2': WeightedL2,
'value_l1': ValueL1,
'value_l2': ValueL2,
+ 'geodesic_l2': GeodesicL2Loss,
+ 'rotation_translation': RotationTranslationLoss,
+ 'spline': SplineLoss,
}
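
Note: a small usage sketch of the extended `Losses` registry, assuming `diffuser.models.helpers` imports cleanly (it now requires pytorch3d). `RotationTranslationLoss` is used here because `SplineLoss` reads `scene_scale.json` in its constructor:

```python
import torch
from diffuser.models.helpers import Losses  # requires pytorch3d

loss_fn = Losses['rotation_translation']()

# [batch, horizon, 3 translation + 4 quaternion] toy tensors; the
# internal DTW/Hausdorff/Frechet/Chamfer metrics are O(n^2) Python
# loops, so keep the horizon small.
pred = torch.randn(2, 8, 7, dtype=torch.float64)
targ = torch.randn(2, 8, 7, dtype=torch.float64)
loss, info = loss_fn(pred, targ)
print(sorted(info))  # chamfer, dtw, frechet, geodesic error, ...
```
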
diff --git a/diffuser/models/temporal.py b/diffuser/models/temporal.py
index e0b9e5c..0f7854a 100644
--- a/diffuser/models/temporal.py
+++ b/diffuser/models/temporal.py
@@ -17,18 +17,18 @@ class ResidualTemporalBlock(nn.Module):
super().__init__()
self.blocks = nn.ModuleList([
- Conv1dBlock(inp_channels, out_channels, kernel_size),
- Conv1dBlock(out_channels, out_channels, kernel_size),
+ Conv1dBlock(inp_channels, out_channels, kernel_size).to(dtype=torch.float64),
+ Conv1dBlock(out_channels, out_channels, kernel_size).to(dtype=torch.float64),
])
self.time_mlp = nn.Sequential(
nn.Mish(),
- nn.Linear(embed_dim, out_channels),
+ nn.Linear(embed_dim, out_channels).to(dtype=torch.float64),
Rearrange('batch t -> batch t 1'),
- )
+ ).to(dtype=torch.float64)
- self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1) \
- if inp_channels != out_channels else nn.Identity()
+ self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1).to(dtype=torch.float64) \
+ if inp_channels != out_channels else nn.Identity().to(dtype=torch.float64)
def forward(self, x, t):
'''
@@ -37,7 +37,8 @@ class ResidualTemporalBlock(nn.Module):
returns:
out : [ batch_size x out_channels x horizon ]
'''
- out = self.blocks[0](x) + self.time_mlp(t)
+
+ out = self.blocks[0](x) + self.time_mlp(t.double())
out = self.blocks[1](out)
return out + self.residual_conv(x)
@@ -49,11 +50,11 @@ class TemporalUnet(nn.Module):
transition_dim,
cond_dim,
dim=32,
- dim_mults=(1, 2, 4, 8),
+ dim_mults=(1, 2, 4),
):
super().__init__()
- dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
+ dims = [(transition_dim + cond_dim), *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
print(f'[ models/temporal ] Channel dimensions: {in_out}')
@@ -100,7 +101,7 @@ class TemporalUnet(nn.Module):
self.final_conv = nn.Sequential(
Conv1dBlock(dim, dim, kernel_size=5),
- nn.Conv1d(dim, transition_dim, 1),
+ nn.Conv1d(dim, transition_dim, 1).to(dtype=torch.float64),
)
def forward(self, x, cond, time):
@@ -129,7 +130,6 @@ class TemporalUnet(nn.Module):
x = upsample(x)
x = self.final_conv(x)
-
x = einops.rearrange(x, 'b t h -> b h t')
return x
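
Note: the `dims` change widens the UNet input to include the conditioning channels. With the new default `dim_mults=(1, 2, 4)` the encoder/decoder channel pairs work out as below (sizes illustrative):

```python
# Hedged sketch of the channel bookkeeping after the dims change.
transition_dim, cond_dim, dim = 8, 8, 32   # illustrative sizes
dim_mults = (1, 2, 4)                      # new default above

dims = [(transition_dim + cond_dim), *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
print(in_out)  # [(16, 32), (32, 64), (64, 128)]
```
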
diff --git a/diffuser/utils/arrays.py b/diffuser/utils/arrays.py
index c3a9d24..96a7093 100644
--- a/diffuser/utils/arrays.py
+++ b/diffuser/utils/arrays.py
@@ -54,7 +54,7 @@ def batchify(batch):
1) converting np arrays to torch tensors and
2) and ensuring that everything has a batch dimension
'''
- fn = lambda x: to_torch(x[None])
+ fn = lambda x: to_torch(x[None], dtype=torch.float64)
batched_vals = []
for field in batch._fields:
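
Note: `batchify` now promotes inputs to double so they match the float64 layers introduced above. A minimal stand-in for the new behavior:

```python
import numpy as np
import torch

# Hedged stand-in for to_torch(x[None], dtype=torch.float64): add a
# leading batch dimension and promote to double.
x = np.zeros((16, 7), dtype=np.float32)
batched = torch.as_tensor(x[None], dtype=torch.float64)
assert batched.shape == (1, 16, 7) and batched.dtype == torch.float64
```
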
diff --git a/diffuser/utils/serialization.py b/diffuser/utils/serialization.py
index 6cc9db9..039eb64 100644
--- a/diffuser/utils/serialization.py
+++ b/diffuser/utils/serialization.py
@@ -19,7 +19,7 @@ def mkdir(savepath):
return False
def get_latest_epoch(loadpath):
- states = glob.glob1(os.path.join(*loadpath), 'state_*')
+ states = glob.glob1(os.path.join(loadpath), 'state_*')
latest_epoch = -1
for state in states:
epoch = int(state.replace('state_', '').replace('.pt', ''))
diff --git a/diffuser/utils/training.py b/diffuser/utils/training.py
index be3556e..c21e0f0 100644
--- a/diffuser/utils/training.py
+++ b/diffuser/utils/training.py
@@ -4,16 +4,24 @@ import numpy as np
import torch
import einops
import pdb
+from tqdm import tqdm
+import wandb
+from pytorch3d.transforms import axis_angle_to_quaternion
from .arrays import batch_to_device, to_np, to_device, apply_dict
from .timer import Timer
from .cloud import sync_logs
+from ..models.helpers import catmull_rom_spline_with_rotation
def cycle(dl):
while True:
for data in dl:
yield data
+def assert_no_nan_weights(model):
+ for name, param in model.named_parameters():
+ assert not torch.isnan(param).any(), f"NaN detected in parameter: {name}"
+
class EMA():
'''
exponential moving average
@@ -71,13 +79,35 @@ class Trainer(object):
self.gradient_accumulate_every = gradient_accumulate_every
self.dataset = dataset
- self.dataloader = cycle(torch.utils.data.DataLoader(
- self.dataset, batch_size=train_batch_size, num_workers=1, shuffle=True, pin_memory=True
+ dataset_size = len(self.dataset)
+
+ # Read the indices from the .txt file
+ with open(os.path.join(results_folder, 'train_indices.txt'), 'r') as f:
+ self.train_indices = f.read()
+ self.train_indices = [int(i) for i in self.train_indices.split('\n') if i]
+
+ with open(os.path.join(results_folder, 'val_indices.txt'), 'r') as f:
+ self.val_indices = f.read()
+ self.val_indices = [int(i) for i in self.val_indices.split('\n') if i]
+
+
+ self.train_dataset = torch.utils.data.Subset(self.dataset, self.train_indices)
+ self.val_dataset = torch.utils.data.Subset(self.dataset, self.val_indices)
+ self.train_dataloader = cycle(torch.utils.data.DataLoader(
+ self.train_dataset, batch_size=train_batch_size, num_workers=1, pin_memory=True, shuffle=False
+ ))
+
+ self.val_dataloader = cycle(torch.utils.data.DataLoader(
+ self.val_dataset, batch_size=train_batch_size, num_workers=1, pin_memory=True, shuffle=False
))
+
self.dataloader_vis = cycle(torch.utils.data.DataLoader(
self.dataset, batch_size=1, num_workers=0, shuffle=True, pin_memory=True
))
self.renderer = renderer
+
+
+
self.optimizer = torch.optim.Adam(diffusion_model.parameters(), lr=train_lr)
self.logdir = results_folder
@@ -88,6 +118,8 @@ class Trainer(object):
self.reset_parameters()
self.step = 0
+ self.log_to_wandb = False
+
def reset_parameters(self):
self.ema_model.load_state_dict(self.model.state_dict())
@@ -102,36 +134,129 @@ class Trainer(object):
#-----------------------------------------------------------------------------#
def train(self, n_train_steps):
-
+ # Save the indices as .txt files
+ with open(os.path.join(self.logdir, 'train_indices.txt'), 'w') as f:
+ for idx in self.train_indices:
+ f.write(f"{idx}\n")
+ with open(os.path.join(self.logdir, 'val_indices.txt'), 'w') as f:
+ for idx in self.val_indices:
+ f.write(f"{idx}\n")
+
timer = Timer()
- for step in range(n_train_steps):
+ torch.autograd.set_detect_anomaly(True)
+
+ # Setup wandb
+ if self.log_to_wandb:
+ wandb.init(
+ project='trajectory-generation',
+ config={'lr': self.optimizer.param_groups[0]['lr'], 'batch_size': self.batch_size, 'gradient_accumulate_every': self.gradient_accumulate_every},
+ )
+
+ for step in tqdm(range(n_train_steps)):
+
+ mean_train_loss = 0.0
for i in range(self.gradient_accumulate_every):
- batch = next(self.dataloader)
+ batch = next(self.train_dataloader)
batch = batch_to_device(batch)
-
- loss, infos = self.model.loss(*batch)
+
+ loss, infos = self.model.loss(x=batch.trajectories, cond=batch.conditions)
loss = loss / self.gradient_accumulate_every
+ mean_train_loss += loss.item()
loss.backward()
+ if self.log_to_wandb:
+ wandb.log({
+ 'step': self.step,
+ 'train/loss': mean_train_loss
+ })
+
+ # torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
+
self.optimizer.step()
self.optimizer.zero_grad()
+ assert_no_nan_weights(self.model)
+
if self.step % self.update_ema_every == 0:
self.step_ema()
if self.step % self.save_freq == 0:
- label = self.step // self.label_freq * self.label_freq
+ label = self.step
+ print(f'Saving model at step {self.step}...')
self.save(label)
if self.step % self.log_freq == 0:
- infos_str = ' | '.join([f'{key}: {val:8.4f}' for key, val in infos.items()])
- print(f'{self.step}: {loss:8.4f} | {infos_str} | t: {timer():8.4f}')
+ val_losses = []
+ lin_int_losses = []
+
+ val_infos_list = []
+ lin_int_infos_list = []
+
+ catmull_losses = []
+ catmull_infos_list = []
+
+ for _ in range(len(self.val_indices)):
+ val_batch = next(self.val_dataloader)
+ val_batch = batch_to_device(val_batch)
+
+ traj = self.model.forward(val_batch.conditions, horizon=val_batch.trajectories.shape[1])
+ val_loss, val_infos = self.model.loss_fn(traj, val_batch.trajectories, cond=val_batch.conditions)
+
+ val_losses.append(val_loss.item())
+ val_infos_list.append({key: val for key, val in val_infos.items()})
+
+
+ (lin_int_loss, lin_int_infos), lin_int_traj = self.linear_interpolation_loss(
+ val_batch.trajectories, val_batch.conditions, self.model.loss_fn
+ )
+ lin_int_losses.append(lin_int_loss.item())
+ lin_int_infos_list.append({key: val for key, val in lin_int_infos.items()})
+
+ (catmull_loss, catmull_infos), catmull_traj = self.catmull_rom_loss(
+ val_batch.trajectories, val_batch.conditions, self.model.loss_fn
+ )
+
+ catmull_losses.append(catmull_loss.item())
+ catmull_infos_list.append(catmull_infos)
+
+ avg_val_loss = np.mean(val_losses)
+ avg_lin_int_loss = np.mean(lin_int_losses)
+
+ val_infos = {key: np.mean([info[key] for info in val_infos_list]) for key in val_infos_list[0].keys()}
+ lin_int_infos = {key: np.mean([info[key] for info in lin_int_infos_list]) for key in lin_int_infos_list[0].keys()}
- if self.step == 0 and self.sample_freq:
- self.render_reference(self.n_reference)
+ avg_catmull_loss = np.mean(catmull_losses)
+ catmull_infos = {key: np.mean([info[key] for info in catmull_infos_list]) for key in catmull_infos_list[0].keys()}
- if self.sample_freq and self.step % self.sample_freq == 0:
- self.render_samples(n_samples=self.n_samples)
+ val_infos_str = ' | '.join([f'{key}: {val:8.4f}' for key, val in val_infos.items()])
+ lin_int_infos_str = ' | '.join([f'{key}: {val:8.4f}' for key, val in lin_int_infos.items()])
+ catmull_infos_str = ' | '.join([f'{key}: {val:8.4f}' for key, val in catmull_infos.items()])
+
+
+ infos_str = ' | '.join([f'{key}: {val:8.4f}' for key, val in infos.items()])
+ print("Learning Rate: ", self.optimizer.param_groups[0]['lr'])
+ print(f'Step {self.step}: {loss * self.gradient_accumulate_every:8.4f} | {infos_str} | t: {timer():8.4f}')
+ print(f'Validation - {self.step}: {avg_val_loss:8.4f} | {val_infos_str} | t: {timer():8.4f}')
+ print(f'Linear Interpolation Loss - {self.step}: {avg_lin_int_loss:8.4f} | {lin_int_infos_str} | t: {timer():8.4f}')
+ print(f'Catmull Rom Loss - {self.step}: {avg_catmull_loss:8.4f} | {catmull_infos_str} | t: {timer():8.4f}')
+ print()
+
+ if self.log_to_wandb:
+ wandb.log({
+ 'step': self.step,
+ 'val/loss': avg_val_loss,
+ 'val/linear_interp/loss': avg_lin_int_loss,
+ 'val/linear_interp/quaternion dist.': lin_int_infos['quat. dist.'],
+ 'val/linear_interp/euclidean dist.': lin_int_infos['trans. error'],
+ 'val/linear_interp/geodesic loss': lin_int_infos['geodesic dist.'],
+ 'val/catmull_rom/loss': avg_catmull_loss,
+ 'val/catmull_rom/quaternion dist.': catmull_infos['quat. dist.'],
+ 'val/catmull_rom/euclidean dist.': catmull_infos['trans. error'],
+ 'val/catmull_rom/geodesic loss': catmull_infos['geodesic dist.'],
+ 'val/quaternion dist.': val_infos['quat. dist.'],
+ 'val/euclidean dist.': val_infos['trans. error'],
+ 'val/geodesic loss': val_infos['geodesic dist.'],
+ })
self.step += 1
@@ -186,15 +311,6 @@ class Trainer(object):
normed_observations = trajectories[:, :, self.dataset.action_dim:]
observations = self.dataset.normalizer.unnormalize(normed_observations, 'observations')
- # from diffusion.datasets.preprocessing import blocks_cumsum_quat
- # # observations = conditions + blocks_cumsum_quat(deltas)
- # observations = conditions + deltas.cumsum(axis=1)
-
- #### @TODO: remove block-stacking specific stuff
- # from diffusion.datasets.preprocessing import blocks_euler_to_quat, blocks_add_kuka
- # observations = blocks_add_kuka(observations)
- ####
-
savepath = os.path.join(self.logdir, f'_sample-reference.png')
self.renderer.composite(savepath, observations)
@@ -225,9 +341,6 @@ class Trainer(object):
# [ 1 x 1 x observation_dim ]
normed_conditions = to_np(batch.conditions[0])[:,None]
- # from diffusion.datasets.preprocessing import blocks_cumsum_quat
- # observations = conditions + blocks_cumsum_quat(deltas)
- # observations = conditions + deltas.cumsum(axis=1)
## [ n_samples x (horizon + 1) x observation_dim ]
normed_observations = np.concatenate([
@@ -238,10 +351,70 @@ class Trainer(object):
## [ n_samples x (horizon + 1) x observation_dim ]
observations = self.dataset.normalizer.unnormalize(normed_observations, 'observations')
- #### @TODO: remove block-stacking specific stuff
- # from diffusion.datasets.preprocessing import blocks_euler_to_quat, blocks_add_kuka
- # observations = blocks_add_kuka(observations)
- ####
-
savepath = os.path.join(self.logdir, f'sample-{self.step}-{i}.png')
self.renderer.composite(savepath, observations)
+
+ def linear_interpolation_loss(self, trajectories, conditions, loss_fc, scene_id=None, norm_params=None):
+ batch_size, horizon, transition = trajectories.shape
+
+ # Extract known indices and values
+ known_indices = np.array(list(conditions.keys()), dtype=int)
+ # candidate_no x batch_size x dim
+ known_values = np.stack([c.cpu().numpy() for c in conditions.values()], axis=0)
+ known_values = np.moveaxis(known_values, 0, 1)
+
+ # Create time steps for interpolation
+ time_steps = np.linspace(0, horizon, num=horizon)
+
+ # Perform interpolation across all dimensions at once
+ linear_int_arr = np.array([[
+ np.interp(time_steps, known_indices, known_values[b, :, dim])
+ for dim in range(transition)]
+ for b in range(batch_size)]
+ ).T # Transpose to match shape (horizon, transition)
+
+ # Convert to tensor and move to the same device as trajectories
+ linear_int_arr = np.transpose(linear_int_arr, axes=[2, 0, 1])
+ linear_int_tensor = torch.tensor(linear_int_arr, dtype=torch.float64, device=trajectories.device)
+
+ return loss_fc(linear_int_tensor, trajectories, cond=conditions, scene_id=scene_id, norm_params=norm_params), linear_int_tensor
+
+
+ def catmull_rom_loss(self, trajectories, conditions, loss_fc, scene_id=None, norm_params=None):
+ '''
+ loss for catmull-rom interpolation
+ '''
+
+ batch_size, horizon, transition = trajectories.shape
+
+ # Extract known indices and values
+ known_indices = np.array(list(conditions.keys()), dtype=int)
+ # candidate_no x batch_size x dim
+ known_values = np.stack([c.cpu().numpy() for c in conditions.values()], axis=0)
+ known_values = np.moveaxis(known_values, 0, 1)
+
+ # Sort the timepoints
+ sorted_indices = np.argsort(known_indices)
+ known_indices = known_indices[sorted_indices]
+ known_values = known_values[:, sorted_indices]
+
+ spline_points = np.array([catmull_rom_spline_with_rotation(known_values[b], known_indices, horizon) for b in range(batch_size)])
+
+ # Convert to tensor and move to the same device as trajectories
+ spline_points = torch.tensor(spline_points, dtype=torch.float64, device=trajectories.device)
+
+ assert spline_points.shape == trajectories.shape, f"Shape mismatch: {spline_points.shape} != {trajectories.shape}"
+
+ return loss_fc(spline_points, trajectories, cond=conditions, scene_id=scene_id, norm_params=norm_params), spline_points
+
+
+
+
+
+
+
+
+
+
+
+
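
Note: the patched Trainer expects `train_indices.txt` and `val_indices.txt` (one dataset index per line) to already exist in `results_folder` when it is constructed; the split step that creates them is not part of this diff. A hedged sketch of what it might look like (`val_frac`, the seed, and the folder name are illustrative):

```python
import os
import numpy as np

# Hedged sketch: produce the index files the patched Trainer reads in
# __init__ and re-writes at the start of train().
def write_split(results_folder, n_items, val_frac=0.1, seed=0):
    os.makedirs(results_folder, exist_ok=True)
    rng = np.random.default_rng(seed)
    idxs = rng.permutation(n_items)
    n_val = int(n_items * val_frac)
    splits = {'val_indices.txt': idxs[:n_val],
              'train_indices.txt': idxs[n_val:]}
    for name, chunk in splits.items():
        with open(os.path.join(results_folder, name), 'w') as f:
            f.writelines(f'{i}\n' for i in chunk)

write_split('logs/maze2d-example', n_items=1000)  # illustrative path
```
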
diff --git a/scripts/train.py b/scripts/train.py
index 2c5f299..6728d6f 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -108,6 +108,7 @@ utils.report_parameters(model)
print('Testing forward...', end=' ', flush=True)
batch = utils.batchify(dataset[0])
+
loss, _ = diffusion.loss(*batch)
loss.backward()
print('✓')