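# Last commit by Weiyu Liu (3827c6d): "compute rot 6d does not depend on cuda".
# (Header recovered from a code-hosting page scrape; the page also showed the
# "raw" / "history blame" links and a file size of 2.81 kB.)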
import os
import torch
from StructDiffusion.utils.rotation_continuity import compute_rotation_matrix_from_ortho6d
def get_diffusion_variables_from_9D_actions(struct_xyztheta_inputs, obj_xyztheta_inputs):
    """Build the 9-D diffusion variables for a structure and its objects.

    The last dimension of each input is read as a translation in entries 0-2
    followed by a 3x3 rotation flattened row by row in entries 3-11::

        [[ 3,  4,  5],
         [ 6,  7,  8],
         [ 9, 10, 11]]

    Important: the first two *columns* of the rotation are kept (entries
    3, 6, 9 and 4, 7, 10), not the first two rows.

    :param struct_xyztheta_inputs: structure poses, presumably B x 1 x 12.
    :param obj_xyztheta_inputs: object poses, presumably B x N x 12.
    :return: tensor of shape B x (1 + N) x 9 laid out as
        [tx, ty, tz, col0 (3), col1 (3)], structure first, then objects.
    """
    translation = [0, 1, 2]
    first_col = [3, 6, 9]
    second_col = [4, 7, 10]
    keep = translation + first_col + second_col

    struct_part = struct_xyztheta_inputs[:, :, keep]  # B, 1, 9
    obj_part = obj_xyztheta_inputs[:, :, keep]  # B, N, 9
    # the structure pose occupies slot 0 along the sequence dimension
    return torch.cat((struct_part, obj_part), dim=1)  # B, 1 + N, 9
def get_diffusion_variables_from_H(poses):
    """Extract 9-D diffusion variables from homogeneous 4x4 transforms.

    Each 4x4 pose is flattened row by row, giving entry indices::

        [[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]]

    The translation (3, 7, 11) and the first two rotation columns
    (0, 4, 8 and 1, 5, 9) are kept, in that order.

    :param poses: tensor of shape (..., 4, 4), typically B x N x 4 x 4.
    :return: tensor of shape (..., 9), e.g. B x N x 9.
    :raises ValueError: if the last two dimensions of ``poses`` are not 4 x 4.
    """
    if tuple(poses.shape[-2:]) != (4, 4):
        raise ValueError(f"expected poses of shape (..., 4, 4), got {tuple(poses.shape)}")
    xyz_6d_idxs = [3, 7, 11, 0, 4, 8, 1, 5, 9]
    # flatten only the trailing 4x4 so any leading batch dims are preserved
    return poses.flatten(start_dim=-2)[..., xyz_6d_idxs]
def get_struct_objs_poses(x):
    """Decode 9-D diffusion variables back into 4x4 homogeneous poses.

    :param x: tensor of shape B x (1 + N) x 9. Per pose, entries 0-2 are the
        translation and entries 3-8 are the 6-D rotation representation passed
        to ``compute_rotation_matrix_from_ortho6d``. Index 0 along dim 1 is the
        structure; indices 1..N are the objects.
    :return: ``(struct_pose, pc_poses_in_struct)`` with shapes B x 1 x 4 x 4
        and B x N x 4 x 4, on the same device and with the same dtype as ``x``.
    """
    # important: the noisy x can go out of bounds, so every entry (translation
    # and rotation) is clamped to [-1, 1] before decoding
    x = torch.clamp(x, min=-1, max=1)
    B, num_poses = x.shape[0], x.shape[1]  # num_poses = 1 + N

    # compute_rotation_matrix_from_ortho6d takes in [B', 6], outputs [B', 3, 3]
    x_6d = x[:, :, 3:].reshape(-1, 6)
    x_rot = compute_rotation_matrix_from_ortho6d(x_6d).reshape(B, num_poses, 3, 3)
    x_trans = x[:, :, :3]  # B, 1 + N, 3

    # allocate directly on x's device and dtype so non-float32 inputs are not
    # silently cast to torch.eye's float32 default
    x_full = torch.eye(4, dtype=x.dtype, device=x.device).repeat(B, num_poses, 1, 1)
    x_full[:, :, :3, :3] = x_rot
    x_full[:, :, :3, 3] = x_trans

    struct_pose = x_full[:, :1]  # B, 1, 4, 4
    pc_poses_in_struct = x_full[:, 1:]  # B, N, 4, 4
    return struct_pose, pc_poses_in_struct
def compute_current_and_goal_pc_poses(obj_xyzs, struct_pose, pc_poses_in_struct):
    """Compute the current and goal 4x4 poses of each object point cloud.

    :param obj_xyzs: object point clouds, B x N x P x 3.
    :param struct_pose: structure pose, B x 1 x 4 x 4.
    :param pc_poses_in_struct: object poses in the structure frame, B x N x 4 x 4.
    :return: ``(current_pc_poses, goal_pc_poses)``, both B x N x 4 x 4.
        Current poses have identity rotation and are translated to each
        cloud's centroid (mean over the P points). Goal poses are
        ``struct_pose @ pc_poses_in_struct``.
    """
    B, N = pc_poses_in_struct.shape[:2]

    # build on obj_xyzs' device and dtype so centroids are not cast to float32
    current_pc_poses = torch.eye(4, dtype=obj_xyzs.dtype, device=obj_xyzs.device).repeat(
        B, N, 1, 1
    )  # B, N, 4, 4
    current_pc_poses[:, :, :3, 3] = obj_xyzs.mean(dim=2)  # centroid: B, N, 3

    # struct_pose (B, 1, 4, 4) broadcasts against (B, N, 4, 4) in batched matmul
    goal_pc_poses = torch.matmul(struct_pose, pc_poses_in_struct)  # B, N, 4, 4
    return current_pc_poses, goal_pc_poses