| diff --git a/config/locomotion.py b/config/locomotion.py | |
| deleted file mode 100644 | |
| index 4410bb1..0000000 | |
| --- a/config/locomotion.py | |
| +++ /dev/null | |
| @@ -1,70 +0,0 @@ | |
| -import socket | |
| - | |
| -from diffuser.utils import watch | |
| - | |
| -#------------------------ base ------------------------# | |
| - | |
| -## automatically make experiment names for planning | |
| -## by labelling folders with these args | |
| - | |
| -diffusion_args_to_watch = [ | |
| - ('prefix', ''), | |
| - ('horizon', 'H'), | |
| - ('n_diffusion_steps', 'T'), | |
| -] | |
| - | |
| -base = { | |
| - 'diffusion': { | |
| - ## model | |
| - 'model': 'models.TemporalUnet', | |
| - 'diffusion': 'models.GaussianDiffusion', | |
| - 'horizon': 32, | |
| - 'n_diffusion_steps': 100, | |
| - 'action_weight': 10, | |
| - 'loss_weights': None, | |
| - 'loss_discount': 1, | |
| - 'predict_epsilon': False, | |
| - 'dim_mults': (1, 4, 8), | |
| - 'renderer': 'utils.MuJoCoRenderer', | |
| - | |
| - ## dataset | |
| - 'loader': 'datasets.SequenceDataset', | |
| - 'normalizer': 'LimitsNormalizer', | |
| - 'preprocess_fns': [], | |
| - 'clip_denoised': True, | |
| - 'use_padding': True, | |
| - 'max_path_length': 1000, | |
| - | |
| - ## serialization | |
| - 'logbase': 'logs', | |
| - 'prefix': 'diffusion/', | |
| - 'exp_name': watch(diffusion_args_to_watch), | |
| - | |
| - ## training | |
| - 'n_steps_per_epoch': 10000, | |
| - 'loss_type': 'l2', | |
| - 'n_train_steps': 1e6, | |
| - 'batch_size': 32, | |
| - 'learning_rate': 2e-4, | |
| - 'gradient_accumulate_every': 2, | |
| - 'ema_decay': 0.995, | |
| - 'save_freq': 1000, | |
| - 'sample_freq': 1000, | |
| - 'n_saves': 5, | |
| - 'save_parallel': False, | |
| - 'n_reference': 8, | |
| - 'n_samples': 2, | |
| - 'bucket': None, | |
| - 'device': 'cuda', | |
| - }, | |
| -} | |
| - | |
| -#------------------------ overrides ------------------------# | |
| - | |
| -## put environment-specific overrides here | |
| - | |
| -halfcheetah_medium_expert_v2 = { | |
| - 'diffusion': { | |
| - 'horizon': 16, | |
| - }, | |
| -} | |
| diff --git a/config/maze2d.py b/config/maze2d.py | |
| index a06ac7f..0a8d22a 100644 | |
| --- a/config/maze2d.py | |
| +++ b/config/maze2d.py | |
| @@ -34,11 +34,11 @@ base = { | |
| 'model': 'models.TemporalUnet', | |
| 'diffusion': 'models.GaussianDiffusion', | |
| 'horizon': 256, | |
| - 'n_diffusion_steps': 256, | |
| + 'n_diffusion_steps': 512, | |
| 'action_weight': 1, | |
| 'loss_weights': None, | |
| 'loss_discount': 1, | |
| - 'predict_epsilon': False, | |
| + 'predict_epsilon': True, | |
| 'dim_mults': (1, 4, 8), | |
| 'renderer': 'utils.Maze2dRenderer', | |
| @@ -57,14 +57,14 @@ base = { | |
| 'exp_name': watch(diffusion_args_to_watch), | |
| ## training | |
| - 'n_steps_per_epoch': 10000, | |
| - 'loss_type': 'l2', | |
| - 'n_train_steps': 2e6, | |
| - 'batch_size': 32, | |
| - 'learning_rate': 2e-4, | |
| - 'gradient_accumulate_every': 2, | |
| + 'n_steps_per_epoch': 60000, | |
| + 'loss_type': 'spline', | |
| + 'n_train_steps': 6e4, | |
| + 'batch_size': 1, | |
| + 'learning_rate': 5e-6, | |
| + 'gradient_accumulate_every': 8, | |
| 'ema_decay': 0.995, | |
| - 'save_freq': 1000, | |
| + 'save_freq': 2000, | |
| 'sample_freq': 1000, | |
| 'n_saves': 50, | |
| 'save_parallel': False, | |
| @@ -89,7 +89,6 @@ base = { | |
| 'prefix': 'plans/release', | |
| 'exp_name': watch(plan_args_to_watch), | |
| 'suffix': '0', | |
| - | |
| 'conditional': False, | |
| ## loading | |
| @@ -122,10 +121,10 @@ maze2d_umaze_v1 = { | |
| maze2d_large_v1 = { | |
| 'diffusion': { | |
| 'horizon': 384, | |
| - 'n_diffusion_steps': 256, | |
| + 'n_diffusion_steps': 16, | |
| }, | |
| 'plan': { | |
| 'horizon': 384, | |
| - 'n_diffusion_steps': 256, | |
| + 'n_diffusion_steps': 16, | |
| }, | |
| } | |
| diff --git a/diffuser/datasets/buffer.py b/diffuser/datasets/buffer.py | |
| index 1ad2106..5991f01 100644 | |
| --- a/diffuser/datasets/buffer.py | |
| +++ b/diffuser/datasets/buffer.py | |
| @@ -9,7 +9,7 @@ class ReplayBuffer: | |
| def __init__(self, max_n_episodes, max_path_length, termination_penalty): | |
| self._dict = { | |
| - 'path_lengths': np.zeros(max_n_episodes, dtype=np.int), | |
| + 'path_lengths': np.zeros(max_n_episodes, dtype=np.int_), | |
| } | |
| self._count = 0 | |
| self.max_n_episodes = max_n_episodes | |
| diff --git a/diffuser/datasets/sequence.py b/diffuser/datasets/sequence.py | |
| index 356c540..73c1b04 100644 | |
| --- a/diffuser/datasets/sequence.py | |
| +++ b/diffuser/datasets/sequence.py | |
| @@ -83,6 +83,7 @@ class SequenceDataset(torch.utils.data.Dataset): | |
| actions = self.fields.normed_actions[path_ind, start:end] | |
| conditions = self.get_conditions(observations) | |
| + | |
| trajectories = np.concatenate([actions, observations], axis=-1) | |
| batch = Batch(trajectories, conditions) | |
| return batch | |
| diff --git a/diffuser/models/diffusion.py b/diffuser/models/diffusion.py | |
| index fae4cfd..461680a 100644 | |
| --- a/diffuser/models/diffusion.py | |
| +++ b/diffuser/models/diffusion.py | |
| @@ -2,6 +2,7 @@ import numpy as np | |
| import torch | |
| from torch import nn | |
| import pdb | |
| +import matplotlib.pyplot as plt | |
| import diffuser.utils as utils | |
| from .helpers import ( | |
| @@ -9,6 +10,7 @@ from .helpers import ( | |
| extract, | |
| apply_conditioning, | |
| Losses, | |
| + catmull_rom_spline_with_rotation, | |
| ) | |
| class GaussianDiffusion(nn.Module): | |
| @@ -26,6 +28,7 @@ class GaussianDiffusion(nn.Module): | |
| betas = cosine_beta_schedule(n_timesteps) | |
| alphas = 1. - betas | |
| alphas_cumprod = torch.cumprod(alphas, axis=0) | |
| + print(f"Alphas Cumprod: {alphas_cumprod}") | |
| alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]]) | |
| self.n_timesteps = int(n_timesteps) | |
| @@ -73,7 +76,7 @@ class GaussianDiffusion(nn.Module): | |
| ''' | |
| self.action_weight = action_weight | |
| - dim_weights = torch.ones(self.transition_dim, dtype=torch.float32) | |
| + dim_weights = torch.ones(self.transition_dim, dtype=torch.float64) | |
| ## set loss coefficients for dimensions of observation | |
| if weights_dict is None: weights_dict = {} | |
| @@ -97,18 +100,16 @@ class GaussianDiffusion(nn.Module): | |
| otherwise, model predicts x0 directly | |
| ''' | |
| if self.predict_epsilon: | |
| - return ( | |
| - extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - | |
| - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise | |
| - ) | |
| + return noise | |
| else: | |
| return noise | |
| def q_posterior(self, x_start, x_t, t): | |
| posterior_mean = ( | |
| extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + | |
| - extract(self.posterior_mean_coef2, t, x_t.shape) * x_t | |
| + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t[:, :, self.action_dim:] | |
| ) | |
| + | |
| posterior_variance = extract(self.posterior_variance, t, x_t.shape) | |
| posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) | |
| return posterior_mean, posterior_variance, posterior_log_variance_clipped | |
| @@ -129,7 +130,7 @@ class GaussianDiffusion(nn.Module): | |
| def p_sample(self, x, cond, t): | |
| b, *_, device = *x.shape, x.device | |
| model_mean, _, model_log_variance = self.p_mean_variance(x=x, cond=cond, t=t) | |
| - noise = torch.randn_like(x) | |
| + noise = torch.randn_like(x[:, :, self.action_dim:]) | |
| # no noise when t == 0 | |
| nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) | |
| return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise | |
| @@ -139,22 +140,59 @@ class GaussianDiffusion(nn.Module): | |
| device = self.betas.device | |
| batch_size = shape[0] | |
| - x = torch.randn(shape, device=device) | |
| - x = apply_conditioning(x, cond, self.action_dim) | |
| + # x = torch.randn(shape, device=device, dtype=torch.float64) | |
| + # Extract known indices and values | |
| + known_indices = np.array(list(cond.keys()), dtype=int) | |
| + | |
| + # candidate_no x batch_size x dim | |
| + known_values = np.stack([c.cpu().numpy() for c in cond.values()], axis=0) | |
| + known_values = np.moveaxis(known_values, 0, 1) | |
| + | |
| + # Sort the timepoints | |
| + sorted_indices = np.argsort(known_indices) | |
| + known_indices = known_indices[sorted_indices] | |
| + known_values = known_values[:, sorted_indices] | |
| + | |
| + # Build the structured spline guess | |
| + catmull_spline_trajectory = np.array([ | |
| + catmull_rom_spline_with_rotation(known_values[b, :, :-1], known_indices, shape[1]) | |
| + for b in range(batch_size) | |
| + ]) | |
| + catmull_spline_trajectory = torch.tensor( | |
| + catmull_spline_trajectory, | |
| + dtype=torch.float64, | |
| + device=device | |
| + ) | |
| + | |
| + | |
| + if self.predict_epsilon: | |
| + x = torch.randn((shape[0], shape[1], self.observation_dim), device=device, dtype=torch.float64) | |
| + cond_residual = {k: torch.zeros_like(v)[:, :-1] for k, v in cond.items()} | |
| + is_cond = torch.zeros((shape[0], shape[1], 1), device=device, dtype=torch.float64) | |
| + is_cond[:, known_indices, :] = 1.0 | |
| if return_diffusion: diffusion = [x] | |
| - progress = utils.Progress(self.n_timesteps) if verbose else utils.Silent() | |
| + # progress = utils.Progress(self.n_timesteps) if verbose else utils.Silent() | |
| for i in reversed(range(0, self.n_timesteps)): | |
| + if self.predict_epsilon: | |
| + x = torch.cat([catmull_spline_trajectory, is_cond, x], dim=-1) | |
| + | |
| timesteps = torch.full((batch_size,), i, device=device, dtype=torch.long) | |
| - x = self.p_sample(x, cond, timesteps) | |
| - x = apply_conditioning(x, cond, self.action_dim) | |
| + x = self.p_sample(x, cond_residual, timesteps) | |
| + | |
| + x = apply_conditioning(x, cond_residual, 0) | |
| - progress.update({'t': i}) | |
| + if return_diffusion: diffusion.append(x) | |
| - if return_diffusion: diffusion.append(x) | |
| + x = catmull_spline_trajectory + x | |
| - progress.close() | |
| + | |
| + | |
| + # Normalize the quaternions | |
| + # x[:, :, 3:7] = x[:, :, 3:7] / torch.norm(x[:, :, 3:7], dim=-1, keepdim=True) | |
| + | |
| + # progress.close() | |
| if return_diffusion: | |
| return x, torch.stack(diffusion, dim=1) | |
| @@ -167,7 +205,7 @@ class GaussianDiffusion(nn.Module): | |
| conditions : [ (time, state), ... ] | |
| ''' | |
| device = self.betas.device | |
| - batch_size = len(cond[0]) | |
| + batch_size = len(next(iter(cond.values()))) | |
| horizon = horizon or self.horizon | |
| shape = (batch_size, horizon, self.transition_dim) | |
| @@ -175,38 +213,106 @@ class GaussianDiffusion(nn.Module): | |
| #------------------------------------------ training ------------------------------------------# | |
| - def q_sample(self, x_start, t, noise=None): | |
| + def q_sample(self, x_start, t, spline=None, noise=None): | |
| + x_start_noise = x_start[:, : , :-1] | |
| + x_start_is_cond = x_start[:, :, [-1]] | |
| + | |
| + if spline is None: | |
| + spline = torch.randn_like(x_start_noise) | |
| if noise is None: | |
| - noise = torch.randn_like(x_start) | |
| + noise = torch.randn_like(x_start_noise) | |
| - sample = ( | |
| - extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + | |
| - extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise | |
| - ) | |
| + alpha = extract(self.sqrt_alphas_cumprod, t, x_start.shape) | |
| + oneminusalpha = extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) | |
| + | |
| + # Weighted combination of x_0 and the spline | |
| + out = alpha * x_start_noise + oneminusalpha * noise | |
| + | |
| + # Concatenate the binary feature and the spline as the conditioning | |
| + out = torch.cat([spline, x_start_is_cond, out], dim=-1) | |
| - return sample | |
| + return out | |
| def p_losses(self, x_start, cond, t): | |
| - noise = torch.randn_like(x_start) | |
| + batch_size, horizon, _ = x_start.shape | |
| + # Extract known indices and values | |
| + known_indices = np.array(list(cond.keys()), dtype=int) | |
| + | |
| + # candidate_no x batch_size x dim | |
| + known_values = np.stack([c.cpu().numpy() for c in cond.values()], axis=0) | |
| + known_values = np.moveaxis(known_values, 0, 1) | |
| + | |
| + # Sort the timepoints | |
| + sorted_indices = np.argsort(known_indices) | |
| + known_indices = known_indices[sorted_indices] | |
| + known_values = known_values[:, sorted_indices] | |
| + | |
| + # Build your structured guess | |
| + catmull_spline_trajectory = np.array([ | |
| + catmull_rom_spline_with_rotation(known_values[b, :, :-1], known_indices, horizon) | |
| + for b in range(batch_size) | |
| + ]) | |
| + catmull_spline_trajectory = torch.tensor( | |
| + catmull_spline_trajectory, | |
| + dtype=torch.float64, | |
| + device=x_start.device | |
| + ) | |
| - x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) | |
| - x_noisy = apply_conditioning(x_noisy, cond, self.action_dim) | |
| + # Plot the quaternions | |
| + # plt.plot(x_start[0, :, 3].cpu().numpy()) | |
| + # plt.plot(catmull_spline_trajectory[0, :, 3].cpu().numpy()) | |
| + # plt.legend(["x_start", "catmull_spline"]) | |
| + # plt.show() | |
| + # raise Exception | |
| - x_recon = self.model(x_noisy, cond, t) | |
| - x_recon = apply_conditioning(x_recon, cond, self.action_dim) | |
| - assert noise.shape == x_recon.shape | |
| + if not self.predict_epsilon: | |
| + # Forward diffuse with the structured trajectory | |
| + x_noisy = self.q_sample( | |
| + x_start, | |
| + t, | |
| + spline=catmull_spline_trajectory, | |
| + ) | |
| + x_noisy = apply_conditioning(x_noisy, cond, self.action_dim) | |
| - if self.predict_epsilon: | |
| - loss, info = self.loss_fn(x_recon, noise) | |
| + # Reverse pass guess | |
| + x_recon = self.model(x_noisy, cond, t) | |
| + x_recon = apply_conditioning(x_recon, cond, self.action_dim) | |
| + | |
| + # Then x_recon is the predicted x_0, compare to the true x_0 | |
| + loss, info = self.loss_fn(x_recon, x_start, cond) | |
| else: | |
| - loss, info = self.loss_fn(x_recon, x_start) | |
| + residual = x_start.clone() | |
| + | |
| + residual[:, :, :-1] -= catmull_spline_trajectory | |
| + | |
| + | |
| + cond_residual = {k: torch.zeros_like(v)[:, :-1] for k, v in cond.items()} | |
| + | |
| + x_noisy = self.q_sample( | |
| + residual, | |
| + t, | |
| + spline=catmull_spline_trajectory, | |
| + ) | |
| + x_noisy = apply_conditioning(x_noisy, cond_residual, self.action_dim) | |
| + | |
| + # Reverse pass guess | |
| + x_recon = self.model(x_noisy, cond, t) | |
| + x_recon = apply_conditioning(x_recon, cond_residual, 0) | |
| + | |
| + x_recon = x_recon + catmull_spline_trajectory | |
| + | |
| + loss, info = self.loss_fn(x_recon, x_start[:, :, :-1], cond) | |
| return loss, info | |
| def loss(self, x, cond): | |
| batch_size = len(x) | |
| t = torch.randint(0, self.n_timesteps, (batch_size,), device=x.device).long() | |
| + # t = torch.randint(1, 2, (batch_size,), device=x.device).long() | |
| + # x = x.double() | |
| + # cond = {k: v.double() for k, v in cond.items()} | |
| + # print(f"Time: {t.item()}") | |
| return self.p_losses(x, cond, t) | |
| def forward(self, cond, *args, **kwargs): | |
| diff --git a/diffuser/models/helpers.py b/diffuser/models/helpers.py | |
| index d39f35d..9f43ef8 100644 | |
| --- a/diffuser/models/helpers.py | |
| +++ b/diffuser/models/helpers.py | |
| @@ -1,11 +1,11 @@ | |
| import math | |
| +import json | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| -import einops | |
| from einops.layers.torch import Rearrange | |
| -import pdb | |
| +from pytorch3d.transforms import quaternion_to_matrix, quaternion_to_axis_angle | |
| import diffuser.utils as utils | |
| @@ -30,7 +30,7 @@ class SinusoidalPosEmb(nn.Module): | |
| class Downsample1d(nn.Module): | |
| def __init__(self, dim): | |
| super().__init__() | |
| - self.conv = nn.Conv1d(dim, dim, 3, 2, 1) | |
| + self.conv = nn.Conv1d(dim, dim, 3, 2, 1).to(torch.float64) | |
| def forward(self, x): | |
| return self.conv(x) | |
| @@ -38,7 +38,7 @@ class Downsample1d(nn.Module): | |
| class Upsample1d(nn.Module): | |
| def __init__(self, dim): | |
| super().__init__() | |
| - self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1) | |
| + self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1).to(torch.float64) | |
| def forward(self, x): | |
| return self.conv(x) | |
| @@ -52,9 +52,9 @@ class Conv1dBlock(nn.Module): | |
| super().__init__() | |
| self.block = nn.Sequential( | |
| - nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), | |
| + nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2).to(torch.float64), | |
| Rearrange('batch channels horizon -> batch channels 1 horizon'), | |
| - nn.GroupNorm(n_groups, out_channels), | |
| + nn.GroupNorm(n_groups, out_channels).to(torch.float64), | |
| Rearrange('batch channels 1 horizon -> batch channels horizon'), | |
| nn.Mish(), | |
| ) | |
| @@ -72,7 +72,7 @@ def extract(a, t, x_shape): | |
| out = a.gather(-1, t) | |
| return out.reshape(b, *((1,) * (len(x_shape) - 1))) | |
| -def cosine_beta_schedule(timesteps, s=0.008, dtype=torch.float32): | |
| +def cosine_beta_schedule(timesteps, s=0.008, dtype=torch.float64): | |
| """ | |
| cosine schedule | |
| as proposed in https://openreview.net/forum?id=-NEXDKk8gZ | |
| @@ -157,9 +157,979 @@ class ValueL2(ValueLoss): | |
| def _loss(self, pred, targ): | |
| return F.mse_loss(pred, targ, reduction='none') | |
| +class GeodesicL2Loss(nn.Module): | |
| + def __init__(self, *args): | |
| + super().__init__() | |
| + pass | |
| + | |
| + def _loss(self, pred, targ): | |
| + # Compute L2 loss for the first three dimensions | |
| + l2_loss = F.mse_loss(pred[..., :3], targ[..., :3], reduction='mean') | |
| + | |
| + # Normalize to unit quaternions for the last four dimensions | |
| + pred_quat = pred[..., 3:] / pred[..., 3:].norm(dim=-1, keepdim=True) | |
| + targ_quat = targ[..., 3:] / targ[..., 3:].norm(dim=-1, keepdim=True) | |
| + | |
| + assert not torch.isnan(pred_quat).any(), "Pred Quat has NaNs" | |
| + assert not torch.isnan(targ_quat).any(), "Targ Quat has NaNs" | |
| + | |
| + # Compute dot product for the quaternions | |
| + dot_product = torch.sum(pred_quat * targ_quat, dim=-1) | |
| + dot_product = torch.clamp(torch.abs(dot_product), -1.0, 1.0) | |
| + | |
| + # Compute geodesic loss for the quaternions | |
| + geodesic_loss = 2 * torch.acos(dot_product).mean() | |
| + | |
| + assert not torch.isnan(geodesic_loss).any(), "Geodesic Loss has NaNs" | |
| + assert not torch.isnan(l2_loss).any(), "L2 Loss has NaNs" | |
| + | |
| + return l2_loss + geodesic_loss, l2_loss, geodesic_loss | |
| + | |
| + def forward(self, pred, targ): | |
| + loss, l2, geodesic = self._loss(pred, targ) | |
| + | |
| + info = { | |
| + 'l2': l2.item(), | |
| + 'geodesic': geodesic.item(), | |
| + } | |
| + | |
| + return loss, info | |
| + | |
| +class RotationTranslationLoss(nn.Module): | |
| + def __init__(self, *args): | |
| + super().__init__() | |
| + pass | |
| + | |
| + def _loss(self, pred, targ, cond=None): | |
| + | |
| + # Make sure the dtype is float64 | |
| + pred = pred.to(torch.float64) | |
| + targ = targ.to(torch.float64) | |
| + | |
| + eps = 1e-8 | |
| + | |
| + pred_trans = pred[..., :3] | |
| + pred_quat = pred[..., 3:7] | |
| + targ_trans = targ[..., :3] | |
| + targ_quat = targ[..., 3:7] | |
| + | |
| + l2_loss = F.mse_loss(pred_trans, targ_trans, reduction='mean') | |
| + | |
| + # Calculate the geodesic loss | |
| + pred_n = pred_quat.norm(dim=-1, keepdim=True).clamp(min=eps) | |
| + targ_n = targ_quat.norm(dim=-1, keepdim=True).clamp(min=eps) | |
| + | |
| + pred_quat_norm = pred_quat / pred_n | |
| + targ_quat_norm = targ_quat / targ_n | |
| + | |
| + | |
| + dot_product = torch.sum(pred_quat_norm * targ_quat_norm, dim=-1).clamp(min=-1.0 + eps, max=1.0 - eps) | |
| + quaternion_dist = 1 - (dot_product ** 2).mean() | |
| + | |
| + # Calculate the rotation error | |
| + pred_rot = quaternion_to_matrix(pred_quat_norm).reshape(-1, 3, 3) | |
| + targ_rot = quaternion_to_matrix(targ_quat_norm).reshape(-1, 3, 3) | |
| + | |
| + r2r1 = pred_rot @ targ_rot.permute(0, 2, 1) | |
| + trace = torch.diagonal(r2r1, dim1=-2, dim2=-1).sum(-1) | |
| + trace = torch.clamp((trace - 1) / 2, -1.0 + eps, 1.0 - eps) | |
| + geodesic_loss = torch.acos(trace).mean() | |
| + | |
| + # Add a smoothness and acceleration term to the positions and quaternions | |
| + alpha = 1.0 | |
| + smoothness_loss = F.mse_loss(pred[:, 1:, :7].reshape(-1, 7), pred[:, :-1, :7].reshape(-1, 7), reduction='mean') | |
| + acceleration_loss = F.mse_loss(pred[:, 2:, :7].reshape(-1, 7), 2 * pred[:, 1:-1, :7].reshape(-1, 7) - pred[:, :-2, :7].reshape(-1, 7), reduction='mean') | |
| + | |
| + l2_multiplier = 10.0 | |
| + | |
| + loss = l2_multiplier * l2_loss + quaternion_dist + geodesic_loss + alpha * (smoothness_loss + acceleration_loss) | |
| + | |
| + dtw = DynamicTimeWarpingLoss() | |
| + dtw_loss, _ = dtw.forward(pred_trans.reshape(-1, 3), targ_trans.reshape(-1, 3)) | |
| + | |
| + hausdorff = HausdorffDistanceLoss() | |
| + hausdorff_loss, _ = hausdorff.forward(pred_trans.reshape(-1, 3), targ_trans.reshape(-1, 3)) | |
| + | |
| + frec = FrechetDistanceLoss() | |
| + frechet_loss, _ = frec.forward(pred_trans.reshape(-1, 3), targ_trans.reshape(-1, 3)) | |
| + | |
| + chamfer = ChamferDistanceLoss() | |
| + chamfer_loss, _ = chamfer.forward(pred_trans.reshape(-1, 3), targ_trans.reshape(-1, 3)) | |
| + | |
| + return loss, l2_loss, geodesic_loss, quaternion_dist, dtw_loss, hausdorff_loss, frechet_loss, chamfer_loss | |
| + | |
| + | |
| + def forward(self, pred, targ, cond=None): | |
| + loss, err_t, err_geo, err_r, err_dtw, err_hausdorff, err_frechet, err_chamfer = self._loss(pred, targ, cond) | |
| + | |
| + info = { | |
| + 'rot. error': err_r.item(), | |
| + 'geodesic error': err_geo.item(), | |
| + 'trans. error': err_t.item(), | |
| + 'dtw': err_dtw.item(), | |
| + 'hausdorff': err_hausdorff.item(), | |
| + 'frechet': err_frechet.item(), | |
| + 'chamfer': err_chamfer.item(), | |
| + } | |
| + | |
| + return loss, info | |
| + | |
| +class SplineLoss(nn.Module): | |
| + def __init__(self, *args): | |
| + super().__init__() | |
| + self.scales = json.load(open('scene_scale.json')) | |
| + | |
| + def compute_spline_coeffs(self, trans): | |
| + p0 = trans[:, :-3, :] | |
| + p1 = trans[:, 1:-2, :] | |
| + p2 = trans[:, 2:-1, :] | |
| + p3 = trans[:, 3:, :] | |
| + | |
| + # Tangent approximations | |
| + m1 = 0.5 * (-p0 + p2) | |
| + m2 = 0.5 * (-p1 + p3) | |
| + | |
| + # Cubic spline coefficients for each dimension | |
| + a = (2 * p1 - 2 * p2 + m1 + m2) | |
| + b = (-3 * p1 + 3 * p2 - 2 * m1 - m2) | |
| + c = (m1) | |
| + d = (p1) | |
| + | |
| + return torch.stack([a, b, c, d], dim=-1) | |
| + | |
| + def q_normalize(self, q): | |
| + return q / q.norm(p=2, dim=-1, keepdim=True).clamp(min=1e-12) | |
| + | |
| + def q_conjugate(self, q): | |
| + w, x, y, z = q[..., 0], q[..., 1], q[..., 2], q[..., 3] | |
| + return torch.stack([w, -x, -y, -z], dim=-1) | |
| + | |
| + def q_multiply(self, q1, q2): | |
| + """ | |
| + q1*q2. | |
| + """ | |
| + w1, x1, y1, z1 = q1.unbind(-1) | |
| + w2, x2, y2, z2 = q2.unbind(-1) | |
| + w = w1*w2 - x1*x2 - y1*y2 - z1*z2 | |
| + x = w1*x2 + x1*w2 + y1*z2 - z1*y2 | |
| + y = w1*y2 - x1*z2 + y1*w2 + z1*x2 | |
| + z = w1*z2 + x1*y2 - y1*x2 + z1*w2 | |
| + return torch.stack([w, x, y, z], dim=-1) | |
| + | |
| + def q_inverse(self, q): | |
| + return self.q_conjugate(self.q_normalize(q)) | |
| + | |
| + def q_log(self, q): | |
| + """ | |
| + Quaternion logarithm for a unit quaternion | |
| + Only returns the imaginary part | |
| + """ | |
| + q = self.q_normalize(q) | |
| + w = q[..., 0] | |
| + xyz = q[..., 1:] # shape [..., 3] | |
| + mag_v = xyz.norm(p=2, dim=-1) | |
| + eps = 1e-12 | |
| + angle = torch.acos(w.clamp(-1.0 + eps, 1.0 - eps)) | |
| + | |
| + # We do a safe-guard against zero for sin(angle) | |
| + small_mask = (mag_v < 1e-12) | (angle < 1e-12) | |
| + # Where small_mask is True => near identity => log(q) ~ 0 | |
| + log_val = torch.zeros_like(xyz) | |
| + | |
| + # Normal case | |
| + scale = angle / mag_v.clamp(min=1e-12) | |
| + normal_case = scale.unsqueeze(-1) * xyz | |
| + | |
| + log_val = torch.where( | |
| + small_mask.unsqueeze(-1), | |
| + torch.zeros_like(xyz), | |
| + normal_case | |
| + ) | |
| + return log_val | |
| + | |
| + def q_exp(self, v): | |
| + """ | |
| + Quaternion exponential | |
| + """ | |
| + norm_v = v.norm(p=2, dim=-1) | |
| + small_mask = norm_v < 1e-12 | |
| + | |
| + w = torch.cos(norm_v) | |
| + sin_v = torch.sin(norm_v) | |
| + scale = torch.where( | |
| + small_mask, | |
| + torch.zeros_like(norm_v), # if zero, sin(0)/0 => 0 | |
| + sin_v / norm_v.clamp(min=1e-12) | |
| + ) | |
| + xyz = scale.unsqueeze(-1) * v | |
| + | |
| + # For small angles, we approximate cos(norm_v) ~ 1, sin(norm_v)/norm_v ~ 1 | |
| + w = torch.where( | |
| + small_mask, | |
| + torch.ones_like(w), | |
| + w | |
| + ) | |
| + return torch.cat([w.unsqueeze(-1), xyz], dim=-1) | |
| + | |
| + def q_slerp(self, q1, q2, t): | |
| + """ | |
| + Spherical linear interpolation from q1 to q2 at t in [0,1]. | |
| + Both q1, q2 assumed normalized. | |
| + q1, q2, t can be 1D or broadcastable shapes, but typically 1D. | |
| + """ | |
| + q1 = self.q_normalize(q1) | |
| + q2 = self.q_normalize(q2) | |
| + dot = (q1 * q2).sum(dim=-1, keepdim=True) # the dot product | |
| + | |
| + eps = 1e-12 | |
| + dot = dot.clamp(-1.0 + eps, 1.0 - eps) | |
| + | |
| + flip_mask = dot < 0.0 | |
| + if flip_mask.any(): | |
| + q2 = torch.where(flip_mask, -q2, q2) | |
| + dot = torch.where(flip_mask, -dot, dot) | |
| + | |
| + # If they're very close, do a simple linear interpolation | |
| + close_mask = dot.squeeze(-1) > 0.9995 | |
| + # Using an epsilon to avoid potential issues close to 1.0 | |
| + | |
| + # Branch 1: Very close | |
| + # linear LERP | |
| + lerp_val = (1.0 - t) * q1 + t * q2 | |
| + lerp_val = self.q_normalize(lerp_val) | |
| + | |
| + # Branch 2: Standard SLERP | |
| + theta_0 = torch.acos(dot) | |
| + sin_theta_0 = torch.sin(theta_0) | |
| + theta = theta_0 * t | |
| + s1 = torch.sin(theta_0 - theta) / sin_theta_0.clamp(min=1e-12) | |
| + s2 = torch.sin(theta) / sin_theta_0.clamp(min=1e-12) | |
| + slerp_val = s1 * q1 + s2 * q2 | |
| + slerp_val = self.q_normalize(slerp_val) | |
| + | |
| + # Combine | |
| + return torch.where( | |
| + close_mask.unsqueeze(-1), | |
| + lerp_val, | |
| + slerp_val | |
| + ) | |
| + | |
| + def compute_uniform_tangent(self, q_im1, q_i, q_ip1): | |
| + """ | |
| + Computes a 'Catmull–Rom-like' tangent T_i for quaternion q_i, | |
| + given neighbors q_im1, q_i, q_ip1. | |
| + | |
| + T_i = q_i * exp( -0.25 * [ log(q_i^-1 q_ip1) + log(q_i^-1 q_im1) ] ) | |
| + """ | |
| + q_im1 = self.q_normalize(q_im1) | |
| + q_i = self.q_normalize(q_i) | |
| + q_ip1 = self.q_normalize(q_ip1) | |
| + | |
| + inv_qi = self.q_inverse(q_i) | |
| + r1 = self.q_multiply(inv_qi, q_ip1) | |
| + r2 = self.q_multiply(inv_qi, q_im1) | |
| + | |
| + lr1 = self.q_log(r1) | |
| + lr2 = self.q_log(r2) | |
| + | |
| + m = -0.25 * (lr1 + lr2) | |
| + exp_m = self.q_exp(m) | |
| + return self.q_multiply(q_i, exp_m) | |
| + | |
| + def compute_all_uniform_tangents(self, quats): | |
| + """ | |
| + Vectorized version that computes tangents T_i for all keyframe quaternions at once. | |
| + quats shape: [N,4], N >= 2 | |
| + Returns shape [N,4]. | |
| + """ | |
| + q_im1 = torch.cat([quats[[0]], quats[:-1]], dim=0) # q_im1[0] = q0 | |
| + q_ip1 = torch.cat([quats[1:], quats[[-1]]], dim=0) # q_ip1[N-1]= q_{N-1} | |
| + | |
| + return self.compute_uniform_tangent(q_im1, quats, q_ip1) | |
| + | |
| + def squad(self, q0, a, b, q1, t): | |
| + """ | |
| + Shoemake's "squad" interpolation for quaternion splines: | |
| + squad(q0, a, b, q1; t) = slerp( slerp(q0, q1; t), | |
| + slerp(a, b; t), | |
| + 2t(1-t) ) | |
| + where a, b are tangential control quaternions for q0, q1. | |
| + """ | |
| + s1 = self.q_slerp(q0, q1, t) | |
| + s2 = self.q_slerp(a, b, t) | |
| + alpha = 2.0*t*(1.0 - t) | |
| + return self.q_slerp(s1, s2, alpha) | |
| + | |
| + def uniform_cr_spline(self, quats, num_samples_per_segment=10): | |
| + """ | |
| + Given a list of keyframe quaternions quats (each a torch 1D tensor [4]), | |
| + compute a "Uniform Catmull–Rom–like" quaternion spline through them. | |
| + | |
| + Returns: | |
| + A list (Python list) of interpolated quaternions (torch tensors), | |
| + including all segment endpoints. | |
| + | |
| + Each interior qi gets a tangent T_i using neighbors q_{i-1}, q_i, q_{i+1}. | |
| + For boundary tangents, we replicate the end quaternions. | |
| + """ | |
| + n = quats.shape[0] | |
| + if n < 2: | |
| + return quats.unsqueeze(0) # not enough quats to interpolate | |
| + | |
| + # Precompute tangents | |
| + tangents = self.compute_all_uniform_tangents(quats) | |
| + | |
| + # Interpolate each segment [qi, q_{i+1}] | |
| + q0 = quats[:-1].unsqueeze(1) | |
| + q1 = quats[1:].unsqueeze(1) | |
| + a = tangents[:-1].unsqueeze(1) | |
| + b = tangents[1:].unsqueeze(1) | |
| + | |
| + t_vals = torch.linspace(0.0, 1.0, num_samples_per_segment, device=quats.device, dtype=quats.dtype) | |
| + t_vals = t_vals.view(1, -1, 1) | |
| + | |
| + out = self.squad(q0, a, b, q1, t_vals) | |
| + return out | |
| + | |
| + | |
| + def forward(self, pred, targ, cond=None, scene_id=None, norm_params=None): | |
| + loss, err_t, err_smooth, err_geo, err_r, err_dtw, err_hausdorff, err_frechet, err_chamfer = self._loss(pred, targ, cond, scene_id, norm_params) | |
| + | |
| + info = { | |
| + 'trans. error': err_t.item(), | |
| + 'smoothness error': err_smooth.item(), | |
| + # 'dtw': err_dtw.item(), | |
| + # 'hausdorff': err_hausdorff.item(), | |
| + # 'frechet': err_frechet.item(), | |
| + # 'chamfer': err_chamfer.item(), | |
| + 'quat. dist.': err_r.item(), | |
| + 'geodesic dist.': err_geo.item(), | |
| + } | |
| + | |
| + return loss, info | |
| + | |
| +    def _loss(self, pred, targ, cond=None, scene_id=None, norm_params=None): | |
| +        """ | |
| +        Spline loss between predicted and target trajectories. | |
| + | |
| +        Channels 0:3 of the last axis are scored as translations and 3:7 as | |
| +        quaternions. Only the span from the first to the last conditioned | |
| +        timestep is used. Translation splines come from | |
| +        `self.compute_spline_coeffs` and quaternion splines from | |
| +        `self.uniform_cr_spline`; both are sampled at the same locations, | |
| +        which are spread over the keyframe intervals in proportion to the | |
| +        target's path length inside each interval. The input tensors are | |
| +        never modified in place. | |
| + | |
| +        Args: | |
| +            pred: predicted trajectory; axis 1 is time (sliced by the keys | |
| +                of `cond`) and the last axis holds the channels. | |
| +            targ: target trajectory with the same layout as `pred`. | |
| +            cond: dict whose keys are the conditioned timestep indices. | |
| +                Required despite the None default. | |
| +            scene_id: optional one-element tensor used as a key into | |
| +                `self.scales`. | |
| +            norm_params: optional dict with a 'scale' entry. Translations are | |
| +                rescaled only when both `scene_id` and `norm_params` are given. | |
| + | |
| +        Returns: | |
| +            (spline_loss, l2_multiplier * l2_loss, l2_multiplier * smoothness_loss, | |
| +             geodesic_loss, quaternion_dist, dtw_loss, hausdorff_loss, | |
| +             frechet_loss, chamfer_loss); the last four are None unless the | |
| +             disabled block near the end of this method is re-enabled. | |
| + | |
| +        Raises: | |
| +            ValueError: if `cond` is None or has fewer than two keys. | |
| + | |
| +        NOTE(review): the interval bookkeeping flattens path lengths over | |
| +        axis 0 with reshape(-1), so it is only consistent for batch size 1. | |
| +        """ | |
| +        def poly_eval(coeffs, x): | |
| +            """ | |
| +            Evaluates a polynomial (with highest-degree term first) at points x. | |
| +            coeffs: tensor whose last axis holds the degree + 1 coefficients, | |
| +                highest-degree term first; leading axes are broadcast by matmul. | |
| +            x: 1D tensor of points at which to evaluate the polynomial. | |
| +            Returns: | |
| +                tensor of shape [*coeffs.shape[:-1], len(x)], containing p(x). | |
| +            """ | |
| +            x_powers = torch.stack([x**i for i in range(coeffs.shape[-1] - 1, -1, -1)], dim=-1) | |
| +            x_powers = x_powers.to(torch.float64).to(coeffs.device) | |
| +            y = torch.matmul(coeffs, x_powers.T) | |
| +            return y | |
| + | |
| +        if cond is None: | |
| +            raise ValueError("`cond` is required: its keys select the conditioned timesteps") | |
| +        candidate_idxs = sorted(cond.keys()) | |
| +        if len(candidate_idxs) < 2: | |
| +            raise ValueError("at least two conditioned timesteps are required") | |
| +        start_idx = candidate_idxs[0] | |
| + | |
| +        # Make sure the dtype is float64 | |
| +        pred = pred.to(torch.float64) | |
| +        targ = targ.to(torch.float64) | |
| + | |
| +        # Rescale the translations. New tensors are built with torch.cat | |
| +        # instead of slice assignment: .to() returns the same object when | |
| +        # the input is already float64, so assigning into it would silently | |
| +        # rescale the caller's tensors as well. | |
| +        if scene_id is not None and norm_params is not None: | |
| +            scene_scale = self.scales[str(scene_id.item())] | |
| +            scene_scale = norm_params['scale'][0] * scene_scale | |
| +            pred = torch.cat([pred[..., :3] * scene_scale, pred[..., 3:]], dim=-1) | |
| +            targ = torch.cat([targ[..., :3] * scene_scale, targ[..., 3:]], dim=-1) | |
| + | |
| +        # We only consider interpolated points for loss calculation | |
| +        pred = pred[:, start_idx : candidate_idxs[-1] + 1, :] | |
| +        targ = targ[:, start_idx : candidate_idxs[-1] + 1, :] | |
| + | |
| +        pred_trans = pred[..., :3] | |
| +        pred_quat = pred[..., 3:7] | |
| +        targ_trans = targ[..., :3] | |
| +        targ_quat = targ[..., 3:7] | |
| + | |
| +        pred_coeffs = self.compute_spline_coeffs(pred_trans) | |
| +        targ_coeffs = self.compute_spline_coeffs(targ_trans) | |
| + | |
| +        n_points = 2000 | |
| + | |
| +        # Distribute sample points among keyframe intervals in proportion to | |
| +        # the target path length inside each interval. `dists` is indexed | |
| +        # from start_idx because the tensors were cropped above, so keyframe | |
| +        # indices are offset before slicing (as in the sampling loop below). | |
| +        dists = torch.norm(targ_trans[:, 1:, :] - targ_trans[:, :-1, :], dim=-1).reshape(-1) | |
| +        dists_c = torch.zeros(len(candidate_idxs) - 1, device=pred.device, dtype=dists.dtype) | |
| +        for i, (idx_i0, idx_i1) in enumerate(zip(candidate_idxs[:-1], candidate_idxs[1:])): | |
| +            dists_c[i] = dists[idx_i0 - start_idx : idx_i1 - start_idx].sum() | |
| + | |
| +        total_c = dists_c.sum() | |
| +        if total_c > 0: | |
| +            weights_c = dists_c / total_c | |
| +        else: | |
| +            # Stationary target: avoid 0/0 and spread the samples evenly. | |
| +            weights_c = torch.full_like(dists_c, 1.0 / len(dists_c)) | |
| +        scaled_c = weights_c * n_points | |
| +        points_c = torch.floor(scaled_c).int() | |
| + | |
| +        # Hand out the points lost to flooring, largest remainder first. | |
| +        while points_c.sum() < n_points: | |
| +            idx = torch.argmax(scaled_c - points_c) | |
| +            points_c[idx] += 1 | |
| + | |
| +        # Calculate the spline loss | |
| +        sample_points = 50 | |
| +        x = torch.linspace(0, 1, sample_points, device=pred.device) | |
| +        pred_spline = poly_eval(pred_coeffs, x).permute(0, 1, 3, 2).reshape(-1, sample_points, 3) | |
| +        targ_spline = poly_eval(targ_coeffs, x).permute(0, 1, 3, 2).reshape(-1, sample_points, 3) | |
| + | |
| +        # Build a 0/1 mask with one row per timestep segment and one column | |
| +        # per spline sample, marking the samples that are scored. | |
| +        indexes = [] | |
| +        for c, (idx_i0, idx_i1) in enumerate(zip(candidate_idxs[:-1], candidate_idxs[1:])): | |
| +            p = points_c[c] | |
| +            total_dist = dists_c[c] | |
| +            dist_arr = dists[idx_i0 - start_idx : idx_i1 - start_idx] | |
| + | |
| +            step_distances = (dist_arr / sample_points).repeat_interleave(sample_points) | |
| +            cumul_distances = step_distances.cumsum(dim=0) | |
| + | |
| +            dist_per_pick = total_dist / p | |
| +            pick_targets = torch.arange(1, p + 1, device=dists.device) * dist_per_pick | |
| + | |
| +            pick_idxs = torch.searchsorted(cumul_distances, pick_targets, right=True) | |
| +            pick_idxs = torch.clamp(pick_idxs, max=len(cumul_distances) - 1) | |
| + | |
| +            indexes_1d = torch.zeros_like(step_distances) | |
| +            indexes_1d[pick_idxs] = 1 | |
| + | |
| +            indexes_2d = indexes_1d.view(len(dist_arr), sample_points) | |
| + | |
| +            indexes.append(indexes_2d) | |
| + | |
| +        indexes = torch.cat(indexes)[1: -1]  # The first and last candidates don't have spline representations | |
| + | |
| +        indexes_trans = torch.stack([indexes for _ in range(3)], dim=-1) | |
| +        indexes_quat = torch.stack([indexes for _ in range(4)], dim=-1) | |
| + | |
| +        indexes_trans = indexes_trans.to(torch.bool) | |
| +        indexes_quat = indexes_quat.to(torch.bool) | |
| + | |
| +        pred_trans_selected_values = pred_spline[indexes_trans] | |
| +        targ_trans_selected_values = targ_spline[indexes_trans] | |
| + | |
| +        pred_trans_selected_values = pred_trans_selected_values.reshape(-1, 3) | |
| +        targ_trans_selected_values = targ_trans_selected_values.reshape(-1, 3) | |
| + | |
| +        # Calculate the loss for quaternions | |
| +        pred_quat = pred_quat / pred_quat.norm(dim=-1, keepdim=True).clamp(min=1e-8) | |
| +        targ_quat = targ_quat / targ_quat.norm(dim=-1, keepdim=True).clamp(min=1e-8) | |
| + | |
| +        targ_quat_spline = self.uniform_cr_spline(targ_quat.reshape(-1, 4), num_samples_per_segment=sample_points) | |
| +        pred_quat_spline = self.uniform_cr_spline(pred_quat.reshape(-1, 4), num_samples_per_segment=sample_points) | |
| + | |
| +        targ_quat_spline = targ_quat_spline[1:-1] | |
| +        pred_quat_spline = pred_quat_spline[1:-1] | |
| + | |
| +        pred_quat_selected_values = pred_quat_spline[indexes_quat] | |
| +        targ_quat_selected_values = targ_quat_spline[indexes_quat] | |
| + | |
| +        pred_quat_selected_values = pred_quat_selected_values.reshape(-1, 4) | |
| +        targ_quat_selected_values = targ_quat_selected_values.reshape(-1, 4) | |
| + | |
| +        # Calculate the geodesic loss | |
| +        pred_rot = quaternion_to_matrix(pred_quat_selected_values).reshape(-1, 3, 3) | |
| +        targ_rot = quaternion_to_matrix(targ_quat_selected_values).reshape(-1, 3, 3) | |
| + | |
| +        eps = 1e-12 | |
| +        r2r1 = pred_rot @ targ_rot.permute(0, 2, 1) | |
| +        trace = torch.diagonal(r2r1, dim1=-2, dim2=-1).sum(-1) | |
| +        trace = torch.clamp((trace - 1) / 2, -1.0 + eps, 1.0 - eps) | |
| +        geodesic_loss = torch.acos(trace).mean() | |
| + | |
| +        # Calculate the rotation error | |
| +        dot_product = torch.sum(pred_quat_selected_values * targ_quat_selected_values, dim=-1).clamp(min=-1.0 + eps, max=1.0 - eps) | |
| +        quaternion_dist = 1 - (dot_product ** 2).mean() | |
| + | |
| +        # Calculate the L2 loss | |
| +        l2_loss = F.mse_loss(pred_trans_selected_values, targ_trans_selected_values, reduction='mean') | |
| + | |
| +        # Calculate the smoothness loss for translation and quaternion | |
| +        smoothness_multiplier = 10 ** 2  # Empirically determined multiplier for smoothness loss | |
| +        weight_acceleration = 0.1 | |
| +        weight_jerk = 0.05 | |
| + | |
| +        # Second and third finite differences of the sampled positions. | |
| +        pos_acc = pred_trans_selected_values[2:, :] - 2 * pred_trans_selected_values[1:-1, :] + pred_trans_selected_values[:-2, :] | |
| +        pos_jerk = pred_trans_selected_values[3:, :] - 3 * pred_trans_selected_values[2:-1, :] + 3 * pred_trans_selected_values[1:-2, :] - pred_trans_selected_values[:-3, :] | |
| + | |
| +        pos_acceleration_loss = torch.mean(pos_acc ** 2) | |
| +        pos_jerk_loss = torch.mean(pos_jerk ** 2) | |
| + | |
| +        # Relative rotation angle between consecutive samples, after flipping | |
| +        # q1 into the same hemisphere as q0. | |
| +        q0 = pred_quat_selected_values[:-1, :] | |
| +        q1 = pred_quat_selected_values[1:, :] | |
| +        sign = torch.where((q0 * q1).sum(dim=-1) < 0, -1.0, 1.0) | |
| +        q1 = sign.unsqueeze(-1) * q1 | |
| + | |
| +        dq = self.q_multiply(q1, self.q_inverse(q0)) | |
| +        theta = 2 * torch.acos(torch.clamp(dq[..., 0], -1.0 + 1e-8, 1.0 - 1e-8)) | |
| + | |
| +        rot_acc = theta[2:] - 2*theta[1:-1] + theta[:-2] | |
| +        rot_jerk = theta[3:] - 3*theta[2:-1] + 3*theta[1:-2] - theta[:-3] | |
| + | |
| +        rot_acceleration_loss = torch.mean(rot_acc ** 2) | |
| +        rot_jerk_loss = torch.mean(rot_jerk ** 2) | |
| + | |
| +        alpha_rot = 0.1  # <-- tune this (e.g. 0.1 … 10) | |
| + | |
| +        acceleration_loss = pos_acceleration_loss + alpha_rot * rot_acceleration_loss | |
| +        jerk_loss = pos_jerk_loss + alpha_rot * rot_jerk_loss | |
| + | |
| +        smoothness_loss = ( | |
| +            weight_acceleration * acceleration_loss | |
| +            + weight_jerk * jerk_loss | |
| +        ) * smoothness_multiplier | |
| + | |
| +        # Calculate the spline loss | |
| +        l2_multiplier = 10.0 | |
| +        spline_loss = l2_multiplier * (l2_loss + smoothness_loss) + geodesic_loss + quaternion_dist | |
| + | |
| +        dtw_loss, hausdorff_loss, frechet_loss, chamfer_loss = None, None, None, None | |
| + | |
| +        # Uncomment these lines if you want to use the other losses | |
| +        ''' | |
| +        dtw = DynamicTimeWarpingLoss() | |
| +        dtw_loss, _ = dtw.forward(pred_trans_selected_values.reshape(-1, 3), targ_trans_selected_values.reshape(-1, 3)) | |
| + | |
| +        hausdorff = HausdorffDistanceLoss() | |
| +        hausdorff_loss, _ = hausdorff.forward(pred_trans_selected_values.reshape(-1, 3), targ_trans_selected_values.reshape(-1, 3)) | |
| + | |
| +        frec = FrechetDistanceLoss() | |
| +        frechet_loss, _ = frec.forward(pred_trans_selected_values.reshape(-1, 3), targ_trans_selected_values.reshape(-1, 3)) | |
| + | |
| +        chamfer = ChamferDistanceLoss() | |
| +        chamfer_loss, _ = chamfer.forward(pred_trans_selected_values.reshape(-1, 3), targ_trans_selected_values.reshape(-1, 3)) | |
| +        ''' | |
| + | |
| +        return spline_loss, l2_multiplier * l2_loss, l2_multiplier * smoothness_loss, geodesic_loss, quaternion_dist, dtw_loss, hausdorff_loss, frechet_loss, chamfer_loss | |
| + | |
| + | |
| +class DynamicTimeWarpingLoss(nn.Module): | |
| +    """ | |
| +    Batch-averaged Dynamic Time Warping (DTW) distance with a Euclidean | |
| +    per-step cost. `forward` returns (loss, {'dtw': float}). | |
| +    """ | |
| + | |
| +    def __init__(self): | |
| +        super().__init__() | |
| + | |
| +    @staticmethod | |
| +    def _pairwise_distances(seq1: torch.Tensor, seq2: torch.Tensor) -> torch.Tensor: | |
| +        """ | |
| +        Returns the (n x m) matrix of Euclidean distances between the rows | |
| +        of seq1 and seq2, cast to seq1's dtype. | |
| + | |
| +        1D inputs are treated as sequences of scalars (feature_dim = 1), | |
| +        which gives |a - b|, the same as torch.norm of a scalar difference. | |
| +        """ | |
| +        if seq1.dim() == 1: | |
| +            seq1 = seq1.unsqueeze(-1) | |
| +        if seq2.dim() == 1: | |
| +            seq2 = seq2.unsqueeze(-1) | |
| +        common = torch.promote_types(seq1.dtype, seq2.dtype) | |
| +        # One vectorised call replaces an n*m Python loop of tiny kernels. | |
| +        # The non-matmul mode evaluates ||a - b|| directly, avoiding the | |
| +        # cancellation error of the ||a||^2 + ||b||^2 - 2<a, b> shortcut. | |
| +        cost = torch.cdist( | |
| +            seq1.to(common), seq2.to(common), p=2, | |
| +            compute_mode='donot_use_mm_for_euclid_dist', | |
| +        ) | |
| +        return cost.to(seq1.dtype) | |
| + | |
| +    def _dtw_distance(self, seq1: torch.Tensor, seq2: torch.Tensor) -> torch.Tensor: | |
| +        """ | |
| +        Computes the DTW distance between two 2D tensors (T x D), | |
| +        where T is sequence length and D is feature dimension. | |
| + | |
| +        The accumulated-cost table is kept as Python lists of scalar | |
| +        tensors rather than written into one tensor in place: in-place | |
| +        writes would modify tensors that torch.min saved for backward, | |
| +        which makes autograd raise when the loss is differentiated. | |
| +        """ | |
| +        # seq1, seq2 shapes: (time_steps, feature_dim) | |
| +        n, m = seq1.size(0), seq2.size(0) | |
| +        cost = self._pairwise_distances(seq1, seq2) | |
| + | |
| +        inf = torch.tensor(float('inf'), device=seq1.device, dtype=seq1.dtype) | |
| +        zero = torch.zeros((), device=seq1.device, dtype=seq1.dtype) | |
| + | |
| +        # prev_row[j] is the accumulated cost D[i - 1, j]; row 0 is the | |
| +        # boundary: 0 at the origin and inf everywhere else. | |
| +        prev_row = [zero] + [inf] * m | |
| +        for i in range(1, n + 1): | |
| +            row = [inf] | |
| +            for j in range(1, m + 1): | |
| +                row.append(cost[i - 1, j - 1] + torch.min( | |
| +                    torch.min( | |
| +                        prev_row[j],    # Insertion | |
| +                        row[j - 1],     # Deletion | |
| +                    ), | |
| +                    prev_row[j - 1],    # Match | |
| +                )) | |
| +            prev_row = row | |
| + | |
| +        return prev_row[m] | |
| + | |
| +    def _loss(self, pred: torch.Tensor, targ: torch.Tensor) -> torch.Tensor: | |
| +        """ | |
| +        Compute the average DTW loss over a batch of sequences. | |
| + | |
| +        pred, targ shapes: (batch_size, T, D) | |
| +        """ | |
| +        # Ensure shapes match in batch dimension | |
| +        assert pred.size(0) == targ.size(0), "Batch sizes must match." | |
| + | |
| +        # Compute DTW distance per sample in the batch | |
| +        distances = [] | |
| +        for b in range(pred.size(0)): | |
| +            seq1 = pred[b] | |
| +            seq2 = targ[b] | |
| +            dtw_val = self._dtw_distance(seq1, seq2) | |
| +            distances.append(dtw_val) | |
| + | |
| +        # Stack and take mean to get scalar loss | |
| +        dtw_loss = torch.stack(distances).mean() | |
| +        return dtw_loss | |
| + | |
| +    def forward(self, pred: torch.Tensor, targ: torch.Tensor): | |
| +        """ | |
| +        Returns a tuple: (loss, info_dict), | |
| +        where loss is a scalar tensor and info_dict is a dictionary | |
| +        of extra information (e.g., loss components). | |
| +        """ | |
| +        loss = self._loss(pred, targ) | |
| + | |
| +        info = { | |
| +            'dtw': loss.item() | |
| +        } | |
| + | |
| +        return loss, info | |
| + | |
| +class HausdorffDistanceLoss(nn.Module): | |
| +    """ | |
| +    Batch-averaged symmetric Hausdorff distance between point sets with a | |
| +    Euclidean ground metric. `forward` returns (loss, {'hausdorff': float}). | |
| +    """ | |
| + | |
| +    def __init__(self): | |
| +        super().__init__() | |
| + | |
| +    @staticmethod | |
| +    def _pairwise_distances(set1: torch.Tensor, set2: torch.Tensor) -> torch.Tensor: | |
| +        """ | |
| +        Returns the (n x m) matrix of Euclidean distances between the rows | |
| +        of set1 and set2, cast to set1's dtype. | |
| + | |
| +        1D inputs are treated as sets of scalars (feature_dim = 1), which | |
| +        gives |a - b|, the same as torch.norm of a scalar difference. | |
| +        """ | |
| +        if set1.dim() == 1: | |
| +            set1 = set1.unsqueeze(-1) | |
| +        if set2.dim() == 1: | |
| +            set2 = set2.unsqueeze(-1) | |
| +        common = torch.promote_types(set1.dtype, set2.dtype) | |
| +        # One vectorised call replaces an n*m Python loop of tiny kernels; | |
| +        # the non-matmul mode evaluates ||a - b|| directly. | |
| +        cost = torch.cdist( | |
| +            set1.to(common), set2.to(common), p=2, | |
| +            compute_mode='donot_use_mm_for_euclid_dist', | |
| +        ) | |
| +        return cost.to(set1.dtype) | |
| + | |
| +    def _hausdorff_distance(self, set1: torch.Tensor, set2: torch.Tensor) -> torch.Tensor: | |
| +        """ | |
| +        Computes the Hausdorff distance between two 2D tensors (N x D), | |
| +        where N is the number of points and D is the feature dimension. | |
| + | |
| +        The Hausdorff distance H(A,B) between two sets A and B is defined as: | |
| +            H(A, B) = max( h(A, B), h(B, A) ), | |
| +        where | |
| +            h(A, B) = max_{a in A} min_{b in B} d(a, b). | |
| + | |
| +        Here, d(a, b) is the Euclidean distance between points a and b. | |
| +        """ | |
| +        # set1, set2 shapes: (num_points, feature_dim) | |
| +        cost = self._pairwise_distances(set1, set2) | |
| + | |
| +        # Forward direction: for each point in set1, distance to the closest point in set2 | |
| +        forward_hausdorff = cost.min(dim=1)[0].max() | |
| + | |
| +        # Backward direction: for each point in set2, distance to the closest point in set1 | |
| +        backward_hausdorff = cost.min(dim=0)[0].max() | |
| + | |
| +        # Hausdorff distance is the max of the two | |
| +        hausdorff_dist = torch.max(forward_hausdorff, backward_hausdorff) | |
| +        return hausdorff_dist | |
| + | |
| +    def _loss(self, pred: torch.Tensor, targ: torch.Tensor) -> torch.Tensor: | |
| +        """ | |
| +        Compute the average Hausdorff distance over a batch of point sets. | |
| + | |
| +        pred, targ shapes: (batch_size, N, D) | |
| +        """ | |
| +        # Ensure shapes match in batch dimension | |
| +        assert pred.size(0) == targ.size(0), "Batch sizes must match." | |
| + | |
| +        distances = [] | |
| +        for b in range(pred.size(0)): | |
| +            set1 = pred[b] | |
| +            set2 = targ[b] | |
| +            h_dist = self._hausdorff_distance(set1, set2) | |
| +            distances.append(h_dist) | |
| + | |
| +        # Stack and take mean to get scalar loss | |
| +        hausdorff_loss = torch.stack(distances).mean() | |
| +        return hausdorff_loss | |
| + | |
| +    def forward(self, pred: torch.Tensor, targ: torch.Tensor): | |
| +        """ | |
| +        Returns a tuple: (loss, info_dict), | |
| +        where loss is a scalar tensor and info_dict is a dictionary | |
| +        of extra information (e.g., distance components). | |
| +        """ | |
| +        loss = self._loss(pred, targ) | |
| + | |
| +        info = { | |
| +            'hausdorff': loss.item() | |
| +        } | |
| + | |
| +        return loss, info | |
| + | |
| +class FrechetDistanceLoss(nn.Module): | |
| +    """ | |
| +    Batch-averaged discrete Fréchet distance with a Euclidean point | |
| +    metric. `forward` returns (loss, {'frechet': float}). | |
| +    """ | |
| + | |
| +    def __init__(self): | |
| +        super().__init__() | |
| + | |
| +    @staticmethod | |
| +    def _pairwise_distances(seq1: torch.Tensor, seq2: torch.Tensor) -> torch.Tensor: | |
| +        """ | |
| +        Returns the (n x m) matrix of Euclidean distances between the rows | |
| +        of seq1 and seq2, cast to seq1's dtype. | |
| + | |
| +        1D inputs are treated as sequences of scalars (feature_dim = 1), | |
| +        which gives |a - b|, the same as torch.norm of a scalar difference. | |
| +        """ | |
| +        if seq1.dim() == 1: | |
| +            seq1 = seq1.unsqueeze(-1) | |
| +        if seq2.dim() == 1: | |
| +            seq2 = seq2.unsqueeze(-1) | |
| +        common = torch.promote_types(seq1.dtype, seq2.dtype) | |
| +        # One vectorised call replaces an n*m Python loop of tiny kernels; | |
| +        # the non-matmul mode evaluates ||a - b|| directly. | |
| +        cost = torch.cdist( | |
| +            seq1.to(common), seq2.to(common), p=2, | |
| +            compute_mode='donot_use_mm_for_euclid_dist', | |
| +        ) | |
| +        return cost.to(seq1.dtype) | |
| + | |
| +    def _frechet_distance(self, seq1: torch.Tensor, seq2: torch.Tensor) -> torch.Tensor: | |
| +        """ | |
| +        Computes the (discrete) Fréchet distance between two 2D tensors (T x D), | |
| +        where T is the sequence length and D is the feature dimension. | |
| + | |
| +        The Fréchet distance between two curves in discrete form can be computed | |
| +        by filling in a DP table "ca" where: | |
| + | |
| +            ca[i, j] = max( d(seq1[i], seq2[j]), | |
| +                            min(ca[i-1, j], ca[i, j-1], ca[i-1, j-1]) ) | |
| + | |
| +        with boundary conditions handled appropriately. | |
| +        Here, d(seq1[i], seq2[j]) is the Euclidean distance. | |
| + | |
| +        The table is a Python list of lists of scalar tensors instead of a | |
| +        tensor filled in place: in-place writes would modify tensors saved | |
| +        by torch.max/torch.min for backward, making autograd raise. | |
| +        """ | |
| +        n, m = seq1.size(0), seq2.size(0) | |
| +        cost = self._pairwise_distances(seq1, seq2) | |
| + | |
| +        # DP table for the Fréchet distance | |
| +        ca = [[None] * m for _ in range(n)] | |
| +        ca[0][0] = cost[0, 0] | |
| + | |
| +        # Initialize first column (j = 0) | |
| +        for i in range(1, n): | |
| +            ca[i][0] = torch.max(ca[i - 1][0], cost[i, 0]) | |
| + | |
| +        # Initialize first row (i = 0) | |
| +        for j in range(1, m): | |
| +            ca[0][j] = torch.max(ca[0][j - 1], cost[0, j]) | |
| + | |
| +        # Populate the DP table | |
| +        for i in range(1, n): | |
| +            for j in range(1, m): | |
| +                ca[i][j] = torch.max( | |
| +                    cost[i, j], | |
| +                    torch.min( | |
| +                        torch.min( | |
| +                            ca[i - 1][j], | |
| +                            ca[i][j - 1], | |
| +                        ), | |
| +                        ca[i - 1][j - 1] | |
| +                    ) | |
| +                ) | |
| + | |
| +        return ca[n - 1][m - 1] | |
| + | |
| +    def _loss(self, pred: torch.Tensor, targ: torch.Tensor) -> torch.Tensor: | |
| +        """ | |
| +        Compute the average Fréchet distance over a batch of sequences. | |
| + | |
| +        pred, targ shapes: (batch_size, T, D) | |
| +        """ | |
| +        # Ensure shapes match in batch dimension | |
| +        assert pred.size(0) == targ.size(0), "Batch sizes must match." | |
| + | |
| +        distances = [] | |
| +        for b in range(pred.size(0)): | |
| +            seq1 = pred[b] | |
| +            seq2 = targ[b] | |
| +            fd_val = self._frechet_distance(seq1, seq2) | |
| +            distances.append(fd_val) | |
| + | |
| +        # Stack and take mean to get scalar loss | |
| +        frechet_loss = torch.stack(distances).mean() | |
| +        return frechet_loss | |
| + | |
| +    def forward(self, pred: torch.Tensor, targ: torch.Tensor): | |
| +        """ | |
| +        Returns a tuple: (loss, info_dict), | |
| +        where loss is a scalar tensor and info_dict is a dictionary | |
| +        of extra information (e.g., distance components). | |
| +        """ | |
| +        loss = self._loss(pred, targ) | |
| +        info = { | |
| +            'frechet': loss.item() | |
| +        } | |
| +        return loss, info | |
| + | |
| +class ChamferDistanceLoss(nn.Module): | |
| +    """ | |
| +    Batch-averaged symmetric Chamfer distance between point sets with a | |
| +    Euclidean ground metric. `forward` returns (loss, {'chamfer': float}). | |
| +    """ | |
| + | |
| +    def __init__(self): | |
| +        super().__init__() | |
| + | |
| +    @staticmethod | |
| +    def _pairwise_distances(set1: torch.Tensor, set2: torch.Tensor) -> torch.Tensor: | |
| +        """ | |
| +        Returns the (n x m) matrix of Euclidean distances between the rows | |
| +        of set1 and set2, cast to set1's dtype. | |
| + | |
| +        1D inputs are treated as sets of scalars (feature_dim = 1), which | |
| +        gives |a - b|, the same as torch.norm of a scalar difference. | |
| +        """ | |
| +        if set1.dim() == 1: | |
| +            set1 = set1.unsqueeze(-1) | |
| +        if set2.dim() == 1: | |
| +            set2 = set2.unsqueeze(-1) | |
| +        common = torch.promote_types(set1.dtype, set2.dtype) | |
| +        # One vectorised call replaces an n*m Python loop of tiny kernels; | |
| +        # the non-matmul mode evaluates ||a - b|| directly. | |
| +        cost = torch.cdist( | |
| +            set1.to(common), set2.to(common), p=2, | |
| +            compute_mode='donot_use_mm_for_euclid_dist', | |
| +        ) | |
| +        return cost.to(set1.dtype) | |
| + | |
| +    def _chamfer_distance(self, set1: torch.Tensor, set2: torch.Tensor) -> torch.Tensor: | |
| +        """ | |
| +        Computes the symmetrical Chamfer distance between | |
| +        two 2D tensors (N x D), where N is the number of points | |
| +        and D is the feature dimension. | |
| + | |
| +        The Chamfer distance between two point sets A and B is often defined as: | |
| + | |
| +            d_chamfer(A, B) = 1/|A| ∑_{a ∈ A} min_{b ∈ B} ‖a - b‖₂ | |
| +                             + 1/|B| ∑_{b ∈ B} min_{a ∈ A} ‖b - a‖₂, | |
| + | |
| +        where ‖·‖₂ is the Euclidean distance. | |
| +        """ | |
| +        # set1, set2 shapes: (num_points, feature_dim) | |
| +        cost = self._pairwise_distances(set1, set2) | |
| + | |
| +        # For each point in set1, find distance to the closest point in set2 | |
| +        forward_mean = cost.min(dim=1)[0].mean() | |
| + | |
| +        # For each point in set2, find distance to the closest point in set1 | |
| +        backward_mean = cost.min(dim=0)[0].mean() | |
| + | |
| +        chamfer_dist = forward_mean + backward_mean | |
| +        return chamfer_dist | |
| + | |
| +    def _loss(self, pred: torch.Tensor, targ: torch.Tensor) -> torch.Tensor: | |
| +        """ | |
| +        Compute the average Chamfer distance over a batch of point sets. | |
| + | |
| +        pred, targ shapes: (batch_size, N, D) | |
| +        """ | |
| +        # Ensure shapes match in batch dimension | |
| +        assert pred.size(0) == targ.size(0), "Batch sizes must match." | |
| + | |
| +        distances = [] | |
| +        for b in range(pred.size(0)): | |
| +            set1 = pred[b] | |
| +            set2 = targ[b] | |
| +            distance_val = self._chamfer_distance(set1, set2) | |
| +            distances.append(distance_val) | |
| + | |
| +        # Combine into a single scalar | |
| +        chamfer_loss = torch.stack(distances).mean() | |
| +        return chamfer_loss | |
| + | |
| +    def forward(self, pred: torch.Tensor, targ: torch.Tensor): | |
| +        """ | |
| +        Returns a tuple: (loss, info_dict), | |
| +        where 'loss' is a scalar tensor and 'info_dict' is a dictionary | |
| +        of extra information (e.g., distance components). | |
| +        """ | |
| +        loss = self._loss(pred, targ) | |
| +        info = { | |
| +            'chamfer': loss.item() | |
| +        } | |
| +        return loss, info | |
| + | |
| + | |
| +def slerp(q1, q2, t): | |
| +    """ | |
| +    Interpolate along the great-circle arc between two quaternions. | |
| + | |
| +    Both inputs are normalised first, and `q2` is negated when the two | |
| +    point into opposite hemispheres so the shorter arc is followed. When | |
| +    the inputs are nearly parallel (cosine above 0.9995) a normalised | |
| +    linear blend is used instead, avoiding division by a vanishing sine. | |
| + | |
| +    Args: | |
| +        q1: start quaternion (array-like of 4 values). | |
| +        q2: end quaternion (array-like of 4 values). | |
| +        t: interpolation parameter; 0 gives normalised `q1`, 1 gives | |
| +            normalised `q2` up to sign. | |
| + | |
| +    Returns: | |
| +        The interpolated unit quaternion as a NumPy array. | |
| +    """ | |
| +    start = q1 / np.linalg.norm(q1) | |
| +    end = q2 / np.linalg.norm(q2) | |
| +    cos_angle = np.dot(start, end) | |
| + | |
| +    # q and -q encode the same rotation; flip to take the shorter arc. | |
| +    if cos_angle < 0.0: | |
| +        end, cos_angle = -end, -cos_angle | |
| + | |
| +    if cos_angle > 0.9995: | |
| +        blended = start + t * (end - start) | |
| +        return blended / np.linalg.norm(blended) | |
| + | |
| +    # Unit vector orthogonal to `start` in the plane spanned by start/end. | |
| +    ortho = end - start * cos_angle | |
| +    ortho = ortho / np.linalg.norm(ortho) | |
| +    angle = np.arccos(cos_angle) * t | |
| +    return start * np.cos(angle) + ortho * np.sin(angle) | |
| + | |
| +def catmull_rom_spline_with_rotation(control_points, timepoints, horizon): | |
| +    """ | |
| +    Densify sparse keyframes into one pose row per timestep. | |
| + | |
| +    Positions (columns 0:3) are linearly interpolated on the first and last | |
| +    keyframe intervals and follow a uniform Catmull-Rom segment on every | |
| +    interior interval. Rotations (columns 3:7) are always slerped between | |
| +    the two bracketing keyframes. Steps before the first / after the last | |
| +    keyframe repeat that keyframe. A single keyframe is held for all | |
| +    `horizon` steps. | |
| + | |
| +    Args: | |
| +        control_points: 2D array with one row per keyframe; columns 0:3 are | |
| +            read as position and 3:7 as a quaternion. | |
| +        timepoints: ascending timestep index of each keyframe. | |
| +        horizon: total number of output steps. | |
| + | |
| +    Returns: | |
| +        Array with `horizon` rows when the keyframes lie in [0, horizon), | |
| +        7 columns, plus one extra zero column when control_points does not | |
| +        have exactly 7 columns. | |
| +    """ | |
| +    def _pose(xyz, quat): | |
| +        return np.concatenate([np.array(xyz), quat]) | |
| + | |
| +    def _hold(point): | |
| +        # Copy of a keyframe's position and rotation. | |
| +        return _pose([point[0], point[1], point[2]], point[3:7]) | |
| + | |
| +    def _lerp_pose(a, b, t): | |
| +        xyz = [a[k] + t * (b[k] - a[k]) for k in range(3)] | |
| +        return _pose(xyz, slerp(a[3:7], b[3:7], t)) | |
| + | |
| +    n_ctrl = len(control_points) | |
| +    spline_points = [] | |
| + | |
| +    if n_ctrl == 1: | |
| +        # No interval to interpolate: hold the only keyframe everywhere. | |
| +        spline_points = [_hold(control_points[0]) for _ in range(horizon)] | |
| +    else: | |
| +        # Hold the first keyframe for the steps before it. | |
| +        for _ in range(timepoints[0]): | |
| +            spline_points.append(_hold(control_points[0])) | |
| + | |
| +        # Linear segment [t_0, t_1], both endpoints included. | |
| +        for t in np.linspace(0, 1, timepoints[1] - timepoints[0] + 1): | |
| +            spline_points.append(_lerp_pose(control_points[0], control_points[1], t)) | |
| + | |
| +        # Interior Catmull-Rom segments (t_i, t_{i+1}]. | |
| +        for i in range(1, n_ctrl - 2): | |
| +            P0 = control_points[i - 1][:3] | |
| +            P1 = control_points[i][:3] | |
| +            P2 = control_points[i + 1][:3] | |
| +            P3 = control_points[i + 2][:3] | |
| +            Q1 = control_points[i][3:7] | |
| +            Q2 = control_points[i + 1][3:7] | |
| +            # [1:] drops t = 0, which the previous segment already emitted. | |
| +            for t in np.linspace(0, 1, timepoints[i + 1] - timepoints[i] + 1)[1:]: | |
| +                xyz = [ | |
| +                    0.5 * ((2 * P1[k]) + (-P0[k] + P2[k]) * t + | |
| +                           (2 * P0[k] - 5 * P1[k] + 4 * P2[k] - P3[k]) * t**2 + | |
| +                           (-P0[k] + 3 * P1[k] - 3 * P2[k] + P3[k]) * t**3) | |
| +                    for k in range(3) | |
| +                ] | |
| +                spline_points.append(_pose(xyz, slerp(Q1, Q2, t))) | |
| + | |
| +        # Linear segment (t_{K-2}, t_{K-1}]. With exactly two keyframes this | |
| +        # is the interval already emitted above; repeating it would add | |
| +        # duplicate rows and break the horizon-length contract. | |
| +        if n_ctrl > 2: | |
| +            for t in np.linspace(0, 1, timepoints[-1] - timepoints[-2] + 1)[1:]: | |
| +                spline_points.append(_lerp_pose(control_points[-2], control_points[-1], t)) | |
| + | |
| +        # Hold the last keyframe for the remaining steps. | |
| +        for _ in range(timepoints[-1] + 1, horizon): | |
| +            spline_points.append(_hold(control_points[-1])) | |
| + | |
| +    stacked_spline_points = np.stack(spline_points, axis=0) | |
| + | |
| +    # Rows only carry the 7 pose columns; when the input had a different | |
| +    # width, append a single zero column. | |
| +    # NOTE(review): exactly one column is added regardless of how many extra | |
| +    # input columns there were. | |
| +    if control_points.shape[1] != 7: | |
| +        stacked_spline_points = np.concatenate([stacked_spline_points, np.zeros((stacked_spline_points.shape[0], 1))], axis=1) | |
| + | |
| +    return stacked_spline_points | |
| + | |
| +def catmull_rom_loss(trajectories, conditions, loss_fc): | |
| +    ''' | |
| +    Evaluate `loss_fc` on a Catmull-Rom baseline built from the keyframes. | |
| + | |
| +    The conditioned timesteps are sorted, each batch element is densified | |
| +    with `catmull_rom_spline_with_rotation`, and the result is compared | |
| +    against `trajectories` via `loss_fc(baseline, trajectories)`. | |
| + | |
| +    Args: | |
| +        trajectories: 3-axis tensor read as (batch, horizon, channels). | |
| +        conditions: dict mapping timestep -> tensor, presumably of shape | |
| +            (batch, channels); values are moved to CPU as NumPy arrays. | |
| +        loss_fc: callable applied to (baseline, trajectories). | |
| + | |
| +    Returns: | |
| +        Whatever `loss_fc` returns. | |
| +    ''' | |
| +    batch_size, horizon, _ = trajectories.shape | |
| + | |
| +    timesteps = np.array(list(conditions.keys()), dtype=int) | |
| +    # Stacking on axis 1 yields (batch, n_keyframes, ...) directly. | |
| +    keyframes = np.stack([value.cpu().numpy() for value in conditions.values()], axis=1) | |
| + | |
| +    # Put keyframes in ascending time order. | |
| +    order = np.argsort(timesteps) | |
| +    timesteps = timesteps[order] | |
| +    keyframes = keyframes[:, order] | |
| + | |
| +    spline_points = np.array([ | |
| +        catmull_rom_spline_with_rotation(keyframes[b], timesteps, horizon) | |
| +        for b in range(batch_size) | |
| +    ]) | |
| + | |
| +    # Match the trajectories' device; the baseline is always float64. | |
| +    spline_points = torch.tensor(spline_points, dtype=torch.float64, device=trajectories.device) | |
| +    assert spline_points.shape == trajectories.shape, f"Shape mismatch: {spline_points.shape} != {trajectories.shape}" | |
| +    return loss_fc(spline_points, trajectories) | |
| + | |
| Losses = { | |
| 'l1': WeightedL1, | |
| 'l2': WeightedL2, | |
| 'value_l1': ValueL1, | |
| 'value_l2': ValueL2, | |
| +    # Entries added by this change. The referenced classes must already be | |
| +    # defined or imported when this module-level dict is evaluated. | |
| +    'geodesic_l2': GeodesicL2Loss, | |
| +    'rotation_translation': RotationTranslationLoss, | |
| +    'spline': SplineLoss, | |
| } | |
| diff --git a/diffuser/models/temporal.py b/diffuser/models/temporal.py | |
| index e0b9e5c..0f7854a 100644 | |
| --- a/diffuser/models/temporal.py | |
| +++ b/diffuser/models/temporal.py | |
| @@ -17,18 +17,18 @@ class ResidualTemporalBlock(nn.Module): | |
| super().__init__() | |
| self.blocks = nn.ModuleList([ | |
| - Conv1dBlock(inp_channels, out_channels, kernel_size), | |
| - Conv1dBlock(out_channels, out_channels, kernel_size), | |
| + Conv1dBlock(inp_channels, out_channels, kernel_size).to(dtype=torch.float64), | |
| + Conv1dBlock(out_channels, out_channels, kernel_size).to(dtype=torch.float64), | |
| ]) | |
| self.time_mlp = nn.Sequential( | |
| nn.Mish(), | |
| - nn.Linear(embed_dim, out_channels), | |
| + nn.Linear(embed_dim, out_channels).to(dtype=torch.float64), | |
| Rearrange('batch t -> batch t 1'), | |
| - ) | |
| + ).to(dtype=torch.float64) | |
| - self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1) \ | |
| - if inp_channels != out_channels else nn.Identity() | |
| + self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1).to(dtype=torch.float64) \ | |
| + if inp_channels != out_channels else nn.Identity().to(dtype=torch.float64) | |
| def forward(self, x, t): | |
| ''' | |
| @@ -37,7 +37,8 @@ class ResidualTemporalBlock(nn.Module): | |
| returns: | |
| out : [ batch_size x out_channels x horizon ] | |
| ''' | |
| - out = self.blocks[0](x) + self.time_mlp(t) | |
| + | |
| + out = self.blocks[0](x) + self.time_mlp(t.double()) | |
| out = self.blocks[1](out) | |
| return out + self.residual_conv(x) | |
| @@ -49,11 +50,11 @@ class TemporalUnet(nn.Module): | |
| transition_dim, | |
| cond_dim, | |
| dim=32, | |
| - dim_mults=(1, 2, 4, 8), | |
| + dim_mults=(1, 2, 4), | |
| ): | |
| super().__init__() | |
| - dims = [transition_dim, *map(lambda m: dim * m, dim_mults)] | |
| + dims = [(transition_dim + cond_dim), *map(lambda m: dim * m, dim_mults)] | |
| in_out = list(zip(dims[:-1], dims[1:])) | |
| print(f'[ models/temporal ] Channel dimensions: {in_out}') | |
| @@ -100,7 +101,7 @@ class TemporalUnet(nn.Module): | |
| self.final_conv = nn.Sequential( | |
| Conv1dBlock(dim, dim, kernel_size=5), | |
| - nn.Conv1d(dim, transition_dim, 1), | |
| + nn.Conv1d(dim, transition_dim, 1).to(dtype=torch.float64), | |
| ) | |
| def forward(self, x, cond, time): | |
| @@ -129,7 +130,6 @@ class TemporalUnet(nn.Module): | |
| x = upsample(x) | |
| x = self.final_conv(x) | |
| - | |
| x = einops.rearrange(x, 'b t h -> b h t') | |
| return x | |
| diff --git a/diffuser/utils/arrays.py b/diffuser/utils/arrays.py | |
| index c3a9d24..96a7093 100644 | |
| --- a/diffuser/utils/arrays.py | |
| +++ b/diffuser/utils/arrays.py | |
| @@ -54,7 +54,7 @@ def batchify(batch): | |
| 1) converting np arrays to torch tensors and | |
| 2) and ensuring that everything has a batch dimension | |
| ''' | |
| - fn = lambda x: to_torch(x[None]) | |
| + fn = lambda x: to_torch(x[None], dtype=torch.float64) | |
| batched_vals = [] | |
| for field in batch._fields: | |
| diff --git a/diffuser/utils/serialization.py b/diffuser/utils/serialization.py | |
| index 6cc9db9..039eb64 100644 | |
| --- a/diffuser/utils/serialization.py | |
| +++ b/diffuser/utils/serialization.py | |
| @@ -19,7 +19,7 @@ def mkdir(savepath): | |
| return False | |
| def get_latest_epoch(loadpath): | |
| - states = glob.glob1(os.path.join(*loadpath), 'state_*') | |
| + states = glob.glob1(os.path.join(loadpath), 'state_*') | |
| latest_epoch = -1 | |
| for state in states: | |
| epoch = int(state.replace('state_', '').replace('.pt', '')) | |
| diff --git a/diffuser/utils/training.py b/diffuser/utils/training.py | |
| index be3556e..c21e0f0 100644 | |
| --- a/diffuser/utils/training.py | |
| +++ b/diffuser/utils/training.py | |
| @@ -4,16 +4,24 @@ import numpy as np | |
| import torch | |
| import einops | |
| import pdb | |
| +from tqdm import tqdm | |
| +import wandb | |
| +from pytorch3d.transforms import axis_angle_to_quaternion | |
| from .arrays import batch_to_device, to_np, to_device, apply_dict | |
| from .timer import Timer | |
| from .cloud import sync_logs | |
| +from ..models.helpers import catmull_rom_spline_with_rotation | |
| +# Infinite generator over `dl`: iteration restarts each time `dl` is | |
| +# exhausted, so next() never raises StopIteration (an empty `dl` would | |
| +# loop forever without yielding). | |
| def cycle(dl): | |
| while True: | |
| for data in dl: | |
| yield data | |
| +def assert_no_nan_weights(model): | |
| +    """Assert that no parameter of `model` contains a NaN, naming the first offender (skipped under python -O).""" | |
| +    for param_name, tensor in model.named_parameters(): | |
| +        nan_found = bool(torch.isnan(tensor).any()) | |
| +        assert not nan_found, f"NaN detected in parameter: {param_name}" | |
| + | |
| class EMA(): | |
| ''' | |
| empirical moving average | |
| @@ -71,13 +79,35 @@ class Trainer(object): | |
| self.gradient_accumulate_every = gradient_accumulate_every | |
| self.dataset = dataset | |
| - self.dataloader = cycle(torch.utils.data.DataLoader( | |
| - self.dataset, batch_size=train_batch_size, num_workers=1, shuffle=True, pin_memory=True | |
| + dataset_size = len(self.dataset) | |
| + | |
| + # Read the indices from the .txt file | |
| + with open(os.path.join(results_folder, 'train_indices.txt'), 'r') as f: | |
| + self.train_indices = f.read() | |
| + self.train_indices = [int(i) for i in self.train_indices.split('\n') if i] | |
| + | |
| + with open(os.path.join(results_folder, 'val_indices.txt'), 'r') as f: | |
| + self.val_indices = f.read() | |
| + self.val_indices = [int(i) for i in self.val_indices.split('\n') if i] | |
| + | |
| + | |
| + self.train_dataset = torch.utils.data.Subset(self.dataset, self.train_indices) | |
| + self.val_dataset = torch.utils.data.Subset(self.dataset, self.val_indices) | |
| + self.train_dataloader = cycle(torch.utils.data.DataLoader( | |
| + self.train_dataset, batch_size=train_batch_size, num_workers=1, pin_memory=True, shuffle=False | |
| + )) | |
| + | |
| + self.val_dataloader = cycle(torch.utils.data.DataLoader( | |
| + self.val_dataset, batch_size=train_batch_size, num_workers=1, pin_memory=True, shuffle=False | |
| )) | |
| + | |
| self.dataloader_vis = cycle(torch.utils.data.DataLoader( | |
| self.dataset, batch_size=1, num_workers=0, shuffle=True, pin_memory=True | |
| )) | |
| self.renderer = renderer | |
| + | |
| + | |
| + | |
| self.optimizer = torch.optim.Adam(diffusion_model.parameters(), lr=train_lr) | |
| self.logdir = results_folder | |
| @@ -88,6 +118,8 @@ class Trainer(object): | |
| self.reset_parameters() | |
| self.step = 0 | |
| + self.log_to_wandb = False | |
| + | |
| def reset_parameters(self): | |
| self.ema_model.load_state_dict(self.model.state_dict()) | |
| @@ -102,36 +134,129 @@ class Trainer(object): | |
| #-----------------------------------------------------------------------------# | |
| def train(self, n_train_steps): | |
| - | |
| + # Save the indices as .txt files | |
| + with open(os.path.join(self.logdir, 'train_indices.txt'), 'w') as f: | |
| + for idx in self.train_indices: | |
| + f.write(f"{idx}\n") | |
| + with open(os.path.join(self.logdir, 'val_indices.txt'), 'w') as f: | |
| + for idx in self.val_indices: | |
| + f.write(f"{idx}\n") | |
| + | |
| timer = Timer() | |
| - for step in range(n_train_steps): | |
| + torch.autograd.set_detect_anomaly(True) | |
| + | |
| + # Setup wandb | |
| + if self.log_to_wandb: | |
| + wandb.init( | |
| + project='trajectory-generation', | |
| + config={'lr': self.optimizer.param_groups[0]['lr'], 'batch_size': self.batch_size, 'gradient_accumulate_every': self.gradient_accumulate_every}, | |
| + ) | |
| + | |
| + for step in tqdm(range(n_train_steps)): | |
| + | |
| + mean_train_loss = 0.0 | |
| for i in range(self.gradient_accumulate_every): | |
| - batch = next(self.dataloader) | |
| + batch = next(self.train_dataloader) | |
| batch = batch_to_device(batch) | |
| - | |
| - loss, infos = self.model.loss(*batch) | |
| + | |
| + loss, infos = self.model.loss(x=batch.trajectories, cond=batch.conditions) | |
| loss = loss / self.gradient_accumulate_every | |
| + mean_train_loss += loss.item() | |
| loss.backward() | |
| + if self.log_to_wandb: | |
| + wandb.log({ | |
| + 'step': self.step, | |
| + 'train/loss': mean_train_loss | |
| + }) | |
| + | |
| + # torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) | |
| + | |
| self.optimizer.step() | |
| self.optimizer.zero_grad() | |
| + assert_no_nan_weights(self.model) | |
| + | |
| if self.step % self.update_ema_every == 0: | |
| self.step_ema() | |
| if self.step % self.save_freq == 0: | |
| - label = self.step // self.label_freq * self.label_freq | |
| + label = self.step | |
| + print(f'Saving model at step {self.step}...') | |
| self.save(label) | |
| if self.step % self.log_freq == 0: | |
| - infos_str = ' | '.join([f'{key}: {val:8.4f}' for key, val in infos.items()]) | |
| - print(f'{self.step}: {loss:8.4f} | {infos_str} | t: {timer():8.4f}') | |
| + val_losses = [] | |
| + lin_int_losses = [] | |
| + | |
| + val_infos_list = [] | |
| + lin_int_infos_list = [] | |
| + | |
| + catmull_losses = [] | |
| + catmull_infos_list = [] | |
| + | |
| + for _ in range(len(self.val_indices)): | |
| + val_batch = next(self.val_dataloader) | |
| + val_batch = batch_to_device(val_batch) | |
| + | |
| + traj = self.model.forward(val_batch.conditions, horizon=val_batch.trajectories.shape[1]) | |
| + val_loss, val_infos = self.model.loss_fn(traj, val_batch.trajectories, cond=val_batch.conditions) | |
| + | |
| + val_losses.append(val_loss.item()) | |
| + val_infos_list.append({key: val for key, val in val_infos.items()}) | |
| + | |
| + | |
| + (lin_int_loss, lin_int_infos), lin_int_traj = self.linear_interpolation_loss( | |
| + val_batch.trajectories, val_batch.conditions, self.model.loss_fn | |
| + ) | |
| + lin_int_losses.append(lin_int_loss.item()) | |
| + lin_int_infos_list.append({key: val for key, val in lin_int_infos.items()}) | |
| + | |
| + (catmull_loss, catmull_infos), catmull_traj = self.catmull_rom_loss( | |
| + val_batch.trajectories, val_batch.conditions, self.model.loss_fn | |
| + ) | |
| + | |
| + catmull_losses.append(catmull_loss.item()) | |
| + catmull_infos_list.append(catmull_infos) | |
| + | |
| + avg_val_loss = np.mean(val_losses) | |
| + avg_lin_int_loss = np.mean(lin_int_losses) | |
| + | |
| + val_infos = {key: np.mean([info[key] for info in val_infos_list]) for key in val_infos_list[0].keys()} | |
| + lin_int_infos = {key: np.mean([info[key] for info in lin_int_infos_list]) for key in lin_int_infos_list[0].keys()} | |
| - if self.step == 0 and self.sample_freq: | |
| - self.render_reference(self.n_reference) | |
| + avg_catmull_loss = np.mean(catmull_losses) | |
| + catmull_infos = {key: np.mean([info[key] for info in catmull_infos_list]) for key in catmull_infos_list[0].keys()} | |
| - if self.sample_freq and self.step % self.sample_freq == 0: | |
| - self.render_samples(n_samples=self.n_samples) | |
| + val_infos_str = ' | '.join([f'{key}: {val:8.4f}' for key, val in val_infos.items()]) | |
| + lin_int_infos_str = ' | '.join([f'{key}: {val:8.4f}' for key, val in lin_int_infos.items()]) | |
| + catmull_infos_str = ' | '.join([f'{key}: {val:8.4f}' for key, val in catmull_infos.items()]) | |
| + | |
| + | |
| + infos_str = ' | '.join([f'{key}: {val:8.4f}' for key, val in infos.items()]) | |
| + print("Learning Rate: ", self.optimizer.param_groups[0]['lr']) | |
| + print(f'Step {self.step}: {loss * self.gradient_accumulate_every:8.4f} | {infos_str} | t: {timer():8.4f}') | |
| + print(f'Validation - {self.step}: {avg_val_loss:8.4f} | {val_infos_str} | t: {timer():8.4f}') | |
| + print(f'Linear Interpolation Loss - {self.step}: {avg_lin_int_loss:8.4f} | {lin_int_infos_str} | t: {timer():8.4f}') | |
| + print(f'Catmull Rom Loss - {self.step}: {avg_catmull_loss:8.4f} | {catmull_infos_str} | t: {timer():8.4f}') | |
| + print() | |
| + | |
| + if self.log_to_wandb: | |
| + wandb.log({ | |
| + 'step': self.step, | |
| + 'val/loss': avg_val_loss, | |
| + 'val/linear_interp/loss': avg_lin_int_loss, | |
| + 'val/linear_interp/quaternion dist.': lin_int_infos['quat. dist.'], | |
| + 'val/linear_interp/euclidean dist.': lin_int_infos['trans. error'], | |
| + 'val/linear_interp/geodesic loss': lin_int_infos['geodesic dist.'], | |
| + 'val/catmull_rom/loss': avg_catmull_loss, | |
| + 'val/catmull_rom/quaternion dist.': catmull_infos['quat. dist.'], | |
| + 'val/catmull_rom/euclidean dist.': catmull_infos['trans. error'], | |
| + 'val/catmull_rom/geodesic loss': catmull_infos['geodesic dist.'], | |
| + 'val/quaternion dist.': val_infos['quat. dist.'], | |
| + 'val/euclidean dist.': val_infos['trans. error'], | |
| + 'val/geodesic loss': val_infos['geodesic dist.'], | |
| + }) | |
| self.step += 1 | |
| @@ -186,15 +311,6 @@ class Trainer(object): | |
| normed_observations = trajectories[:, :, self.dataset.action_dim:] | |
| observations = self.dataset.normalizer.unnormalize(normed_observations, 'observations') | |
| - # from diffusion.datasets.preprocessing import blocks_cumsum_quat | |
| - # # observations = conditions + blocks_cumsum_quat(deltas) | |
| - # observations = conditions + deltas.cumsum(axis=1) | |
| - | |
| - #### @TODO: remove block-stacking specific stuff | |
| - # from diffusion.datasets.preprocessing import blocks_euler_to_quat, blocks_add_kuka | |
| - # observations = blocks_add_kuka(observations) | |
| - #### | |
| - | |
| savepath = os.path.join(self.logdir, f'_sample-reference.png') | |
| self.renderer.composite(savepath, observations) | |
| @@ -225,9 +341,6 @@ class Trainer(object): | |
| # [ 1 x 1 x observation_dim ] | |
| normed_conditions = to_np(batch.conditions[0])[:,None] | |
| - # from diffusion.datasets.preprocessing import blocks_cumsum_quat | |
| - # observations = conditions + blocks_cumsum_quat(deltas) | |
| - # observations = conditions + deltas.cumsum(axis=1) | |
| ## [ n_samples x (horizon + 1) x observation_dim ] | |
| normed_observations = np.concatenate([ | |
| @@ -238,10 +351,70 @@ class Trainer(object): | |
| ## [ n_samples x (horizon + 1) x observation_dim ] | |
| observations = self.dataset.normalizer.unnormalize(normed_observations, 'observations') | |
| - #### @TODO: remove block-stacking specific stuff | |
| - # from diffusion.datasets.preprocessing import blocks_euler_to_quat, blocks_add_kuka | |
| - # observations = blocks_add_kuka(observations) | |
| - #### | |
| - | |
| savepath = os.path.join(self.logdir, f'sample-{self.step}-{i}.png') | |
| self.renderer.composite(savepath, observations) | |
| + | |
| + def linear_interpolation_loss(self, trajectories, conditions, loss_fc, scene_id=None, norm_params=None): | |
| + ''' | |
| + Loss of a piecewise-linear baseline drawn through the conditioned timesteps. | |
| + | |
| + trajectories is unpacked as (batch_size, horizon, transition); conditions maps an | |
| + integer timestep to a per-batch tensor of known values. The interpolant is scored with | |
| + loss_fc(pred, target, cond=..., scene_id=..., norm_params=...). | |
| + Returns (result of loss_fc, interpolated tensor of the same shape as trajectories). | |
| + ''' | |
| + batch_size, horizon, transition = trajectories.shape | |
| + | |
| + # Extract known timesteps and values. | |
| + known_indices = np.array(list(conditions.keys()), dtype=int) | |
| + # candidate_no x batch_size x dim -> batch_size x candidate_no x dim | |
| + known_values = np.stack([c.cpu().numpy() for c in conditions.values()], axis=0) | |
| + known_values = np.moveaxis(known_values, 0, 1) | |
| + | |
| + # np.interp requires increasing sample points and nothing guarantees the dict keys | |
| + # are ordered, so sort them (same as catmull_rom_loss does). | |
| + order = np.argsort(known_indices) | |
| + known_indices = known_indices[order] | |
| + known_values = known_values[:, order] | |
| + | |
| + # Evaluate at the integer trajectory timesteps 0..horizon-1. A linspace(0, horizon, horizon) | |
| + # grid would step by horizon/(horizon-1) and drift away from the real indices. | |
| + time_steps = np.arange(horizon) | |
| + | |
| + # batch_size x transition x horizon | |
| + linear_int_arr = np.array([[ | |
| + np.interp(time_steps, known_indices, known_values[b, :, dim]) | |
| + for dim in range(transition)] | |
| + for b in range(batch_size)] | |
| + ) | |
| + # -> batch_size x horizon x transition, matching trajectories | |
| + linear_int_arr = np.transpose(linear_int_arr, axes=(0, 2, 1)) | |
| + linear_int_tensor = torch.tensor(linear_int_arr, dtype=torch.float64, device=trajectories.device) | |
| + | |
| + return loss_fc(linear_int_tensor, trajectories, cond=conditions, scene_id=scene_id, norm_params=norm_params), linear_int_tensor | |
| + | |
| + | |
| + def catmull_rom_loss(self, trajectories, conditions, loss_fc, scene_id=None, norm_params=None): | |
| + ''' | |
| + Loss of a Catmull-Rom spline baseline drawn through the conditioned timesteps. | |
| + | |
| + Each batch item's spline comes from catmull_rom_spline_with_rotation evaluated over | |
| + `horizon` steps; the stacked splines are scored against trajectories with loss_fc. | |
| + Returns (result of loss_fc, spline tensor). | |
| + ''' | |
| + n_batch, horizon, _ = trajectories.shape | |
| + | |
| + # Conditioned timesteps, with their values arranged as batch x candidate x dim. | |
| + timesteps = np.array(list(conditions.keys()), dtype=int) | |
| + values = np.moveaxis(np.stack([c.cpu().numpy() for c in conditions.values()], axis=0), 0, 1) | |
| + | |
| + # Put the waypoints in chronological order before fitting. | |
| + order = np.argsort(timesteps) | |
| + timesteps, values = timesteps[order], values[:, order] | |
| + | |
| + splines = [catmull_rom_spline_with_rotation(values[b], timesteps, horizon) for b in range(n_batch)] | |
| + spline_points = torch.tensor(np.array(splines), dtype=torch.float64, device=trajectories.device) | |
| + | |
| + assert spline_points.shape == trajectories.shape, f"Shape mismatch: {spline_points.shape} != {trajectories.shape}" | |
| + | |
| + return loss_fc(spline_points, trajectories, cond=conditions, scene_id=scene_id, norm_params=norm_params), spline_points | |
| + | |
| + | |
| + | |
| + | |
| + | |
| + | |
| + | |
| + | |
| + | |
| + | |
| + | |
| + | |
| diff --git a/scripts/train.py b/scripts/train.py | |
| index 2c5f299..6728d6f 100644 | |
| --- a/scripts/train.py | |
| +++ b/scripts/train.py | |
| @@ -108,6 +108,7 @@ utils.report_parameters(model) | |
| print('Testing forward...', end=' ', flush=True) | |
| batch = utils.batchify(dataset[0]) | |
| + | |
| loss, _ = diffusion.loss(*batch) | |
| loss.backward() | |
| print('✓') |