Spaces:
Paused
Paused
import os | |
import torch | |
from StructDiffusion.utils.rotation_continuity import compute_rotation_matrix_from_ortho6d | |
def get_diffusion_variables_from_9D_actions(struct_xyztheta_inputs, obj_xyztheta_inputs):
    """Assemble the 9D diffusion variables for the structure and its objects.

    The last axis of each input is indexed as three translation values
    followed by a rotation matrix flattened row by row:

        [[ 3,  4,  5],
         [ 6,  7,  8],
         [ 9, 10, 11]]

    Important: the first two *columns* of the rotation are kept, not the
    first two rows.

    :param struct_xyztheta_inputs: structure pose; B, 1, 9 after selection
    :param obj_xyztheta_inputs: object poses; B, N, 9 after selection
    :return: B, 1 + N, 9, with the structure entry first along dim 1
    """
    translation = [0, 1, 2]
    rot_first_col = [3, 6, 9]
    rot_second_col = [4, 7, 10]
    keep = translation + rot_first_col + rot_second_col

    obj_9d = obj_xyztheta_inputs[:, :, keep]  # B, N, 9
    struct_9d = struct_xyztheta_inputs[:, :, keep]  # B, 1, 9

    return torch.cat((struct_9d, obj_9d), dim=1)  # B, 1 + N, 9
def get_diffusion_variables_from_H(poses):
    """Convert 4x4 homogeneous transforms into 9D diffusion variables.

    With each matrix flattened row by row, its elements are numbered

        [[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]]

    The result keeps the translation column (3, 7, 11) followed by the first
    two columns of the rotation block (0, 4, 8 and 1, 5, 9).

    :param poses: B, N, 4, 4
    :return: B, N, 9
    """
    batch_size, num_poses, _, _ = poses.shape
    flat_poses = poses.reshape(batch_size, num_poses, 16)

    translation = [3, 7, 11]
    rot_first_col = [0, 4, 8]
    rot_second_col = [1, 5, 9]
    keep = translation + rot_first_col + rot_second_col

    return flat_poses[:, :, keep]  # B, N, 9
def get_struct_objs_poses(x):
    """Turn 9D diffusion variables into 4x4 homogeneous poses.

    :param x: B, 1 + N, 9. Each entry holds three translation values followed
        by a 6D rotation. Index 0 along dim 1 is the structure; the remaining
        N entries are the objects.
    :return: tuple ``(struct_pose, pc_poses_in_struct)`` with shapes
        B, 1, 4, 4 and B, N, 4, 4
    """
    device = x.device

    # important: the noisy x can go out of bounds
    x = torch.clamp(x, min=-1, max=1)

    # x: B, 1 + N, 9
    B = x.shape[0]
    N = x.shape[1] - 1

    # compute_rotation_matrix_from_ortho6d takes in [B, 6], outputs [B, 3, 3]
    x_6d = x[:, :, 3:].reshape(-1, 6)
    x_rot = compute_rotation_matrix_from_ortho6d(x_6d).reshape(B, N + 1, 3, 3)  # B, 1 + N, 3, 3
    x_trans = x[:, :, :3]  # B, 1 + N, 3

    # Allocate the identity matrices directly on x's device rather than
    # building them on the CPU and copying them over afterwards.
    x_full = torch.eye(4, device=device).repeat(B, 1 + N, 1, 1)  # B, 1 + N, 4, 4
    x_full[:, :, :3, :3] = x_rot
    x_full[:, :, :3, 3] = x_trans

    struct_pose = x_full[:, 0].unsqueeze(1)  # B, 1, 4, 4
    pc_poses_in_struct = x_full[:, 1:]  # B, N, 4, 4

    return struct_pose, pc_poses_in_struct
def compute_current_and_goal_pc_poses(obj_xyzs, struct_pose, pc_poses_in_struct):
    """Compute current and goal poses for every object point cloud.

    The current pose of an object is an identity rotation with its translation
    set to the mean of the object's points. The goal pose is the structure
    pose composed with the object's pose in the structure frame.

    :param obj_xyzs: B, N, P, 3
    :param struct_pose: B, 1, 4, 4
    :param pc_poses_in_struct: B, N, 4, 4
    :return: tuple ``(current_pc_poses, goal_pc_poses)``, both B, N, 4, 4
    :raises ValueError: if ``obj_xyzs`` or ``pc_poses_in_struct`` is not 4D
    """
    if obj_xyzs.dim() != 4:
        raise ValueError(
            f"obj_xyzs must have shape (B, N, P, 3), got {tuple(obj_xyzs.shape)}"
        )
    B, N, _, _ = pc_poses_in_struct.shape

    # Allocate the identity matrices directly on the input's device rather
    # than building them on the CPU and copying them over afterwards.
    current_pc_poses = torch.eye(4, device=obj_xyzs.device).repeat(B, N, 1, 1)  # B, N, 4, 4
    current_pc_poses[:, :, :3, 3] = torch.mean(obj_xyzs, dim=2)  # centroid of each object

    # Matmul broadcasts the single structure pose over the N objects, so the
    # pose does not need to be copied N times first.
    goal_pc_poses = struct_pose @ pc_poses_in_struct  # B, N, 4, 4

    return current_pc_poses, goal_pc_poses