| import torch | |
class VerletStandardizer:
    """Optional Verlet re-parameterization of flattened 3-D trajectories.

    By default (``use_verlet=False``) both ``transform_features`` and
    ``untransform_features`` return their input unchanged, which is the
    behaviour this class has always had.

    With ``use_verlet=True`` every future position ``x_t`` is encoded as the
    residual of a constant-velocity (Verlet) prediction,

        action_t = (x_t - (2 * x_{t-1} - x_{t-2})) * max_dist,

    where the first two predecessors are the last two steps of ``history``.
    ``untransform_features`` integrates the actions back from those same two
    history positions, so ``untransform(transform(traj, h), h)`` recovers
    ``traj`` up to floating-point error for any history of length >= 2.
    """

    # Class-level default: unpickling restores only the instance __dict__,
    # so instances saved before this flag existed fall back to the identity.
    use_verlet = False

    def __init__(self, max_dist=50, use_verlet=False):
        """
        Args:
            max_dist: Scale factor applied to Verlet actions; they are
                multiplied by it in ``transform_features`` and divided by it
                in ``untransform_features``. Unused unless ``use_verlet``.
            use_verlet: If False (default) both transforms are the identity.
        """
        super().__init__()
        self.max_dist = max_dist
        self.use_verlet = use_verlet

    @staticmethod
    def _last_two_states(history, batch):
        """Return the last two history positions as a (batch, 2, 3) tensor."""
        past = history.reshape(batch, -1, 3)
        if past.shape[1] < 2:
            raise ValueError(
                "Verlet parameterization needs at least 2 history steps, "
                f"got {past.shape[1]}"
            )
        return past[:, -2:]

    def transform_features(self, trajectory, history):
        """Encode a flattened trajectory as scaled Verlet actions.

        Args:
            trajectory: Future positions, reshapeable to ``(batch, T, 3)``
                (e.g. flattened as ``(batch, T * 3)``).
            history: Past positions, reshapeable to ``(batch, H, 3)`` with
                ``H >= 2``; only the last two steps are used. Ignored when
                ``use_verlet`` is False.

        Returns:
            ``trajectory`` itself when ``use_verlet`` is False, otherwise a
            ``(batch, T * 3)`` tensor of actions scaled by ``max_dist``.

        Raises:
            ValueError: If ``use_verlet`` is True and ``history`` has fewer
                than two steps.
        """
        if not self.use_verlet:
            return trajectory
        batch = trajectory.shape[0]
        future = trajectory.reshape(batch, -1, 3)
        # Seed with the two most recent history steps only, so exactly one
        # action is produced per future step and the inverse is exact.
        positions = torch.cat([self._last_two_states(history, batch), future], dim=1)
        # Second difference x_t - 2 x_{t-1} + x_{t-2}: the deviation of x_t
        # from its constant-velocity prediction.
        actions = torch.diff(positions, n=2, dim=1)
        return (actions * self.max_dist).reshape(batch, -1)

    def untransform_features(self, actions, history):
        """Invert ``transform_features``.

        Args:
            actions: Output of ``transform_features``, reshapeable to
                ``(batch, T, 3)``.
            history: The same history passed to ``transform_features``.
                Ignored when ``use_verlet`` is False.

        Returns:
            ``actions`` itself when ``use_verlet`` is False, otherwise the
            reconstructed trajectory as a ``(batch, T * 3)`` tensor.

        Raises:
            ValueError: If ``use_verlet`` is True and ``history`` has fewer
                than two steps.
        """
        if not self.use_verlet:
            return actions
        batch = actions.shape[0]
        accel = actions.reshape(batch, -1, 3) / self.max_dist
        seed = self._last_two_states(history, batch)
        prev, last = seed[:, :1], seed[:, 1:]  # each (batch, 1, 3)
        # Closed form of x_t = 2 x_{t-1} - x_{t-2} + a_t: velocities are the
        # initial velocity plus a running sum of actions, and positions are
        # the last known position plus a running sum of velocities.
        velocity = (last - prev) + torch.cumsum(accel, dim=1)
        positions = last + torch.cumsum(velocity, dim=1)
        return positions.reshape(batch, -1)