diff --git "a/models/dyn_model_act_v2.py" "b/models/dyn_model_act_v2.py"
new file mode 100644--- /dev/null
+++ "b/models/dyn_model_act_v2.py"
@@ -0,0 +1,2712 @@
+
+import math
+# import torch
+# from ..utils import Timer
+import numpy as np
+# import torch.nn.functional as F
+import os
+
+import argparse
+
+from xml.etree.ElementTree import ElementTree
+
+import trimesh
+import torch
+import torch.nn as nn
+# import List
+# class link; joint; body
+###
+
+from scipy.spatial.transform import Rotation as R
+from torch.distributions.uniform import Uniform
+# deformable articulated objects with the articulated models #
+
+DAMPING = 1.0
+DAMPING = 0.3
+
+urdf_fn = ""
+
+def plane_rotation_matrix_from_angle_xz(angle):
+ sin_ = torch.sin(angle)
+ cos_ = torch.cos(angle)
+ zero_padding = torch.zeros_like(cos_)
+ one_padding = torch.ones_like(cos_)
+ col_a = torch.stack(
+ [cos_, zero_padding, sin_], dim=0
+ )
+ col_b = torch.stack(
+ [zero_padding, one_padding, zero_padding], dim=0
+ )
+ col_c = torch.stack(
+ [-1. * sin_, zero_padding, cos_], dim=0
+ )
+ rot_mtx = torch.stack(
+ [col_a, col_b, col_c], dim=-1
+ )
+ return rot_mtx
+
+def plane_rotation_matrix_from_angle(angle):
+ ## angle of
+ sin_ = torch.sin(angle)
+ cos_ = torch.cos(angle)
+ col_a = torch.stack(
+ [cos_, sin_], dim=0 ### col of the rotation matrix
+ )
+ col_b = torch.stack(
+ [-1. * sin_, cos_], dim=0 ## cols of the rotation matrix
+ )
+ rot_mtx = torch.stack(
+ [col_a, col_b], dim=-1 ### rotation matrix
+ )
+ return rot_mtx
+
+def rotation_matrix_from_axis_angle(axis, angle): # rotation_matrix_from_axis_angle ->
+ # sin_ = np.sin(angle) # ti.math.sin(angle)
+ # cos_ = np.cos(angle) # ti.math.cos(angle)
+ sin_ = torch.sin(angle) # ti.math.sin(angle)
+ cos_ = torch.cos(angle) # ti.math.cos(angle)
+ u_x, u_y, u_z = axis[0], axis[1], axis[2]
+ u_xx = u_x * u_x
+ u_yy = u_y * u_y
+ u_zz = u_z * u_z
+ u_xy = u_x * u_y
+ u_xz = u_x * u_z
+ u_yz = u_y * u_z
+
+ row_a = torch.stack(
+ [cos_ + u_xx * (1 - cos_), u_xy * (1. - cos_) + u_z * sin_, u_xz * (1. - cos_) - u_y * sin_], dim=0
+ )
+ # print(f"row_a: {row_a.size()}")
+ row_b = torch.stack(
+ [u_xy * (1. - cos_) - u_z * sin_, cos_ + u_yy * (1. - cos_), u_yz * (1. - cos_) + u_x * sin_], dim=0
+ )
+ # print(f"row_b: {row_b.size()}")
+ row_c = torch.stack(
+ [u_xz * (1. - cos_) + u_y * sin_, u_yz * (1. - cos_) - u_x * sin_, cos_ + u_zz * (1. - cos_)], dim=0
+ )
+ # print(f"row_c: {row_c.size()}")
+
+ ### rot_mtx for the rot_mtx ###
+ rot_mtx = torch.stack(
+ [row_a, row_b, row_c], dim=-1 ### rot_matrix of he matrix ##
+ )
+
+ return rot_mtx
+
+
+def update_quaternion(delta_angle, prev_quat):
+ s1 = 0
+ s2 = prev_quat[0]
+ v2 = prev_quat[1:]
+ v1 = delta_angle / 2
+ new_v = s1 * v2 + s2 * v1 + torch.cross(v1, v2)
+ new_s = s1 * s2 - torch.sum(v1 * v2)
+ new_quat = torch.cat([new_s.unsqueeze(0), new_v], dim=0)
+ return new_quat
+
+
+def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
+ """
+ Convert rotations given as quaternions to rotation matrices.
+
+ Args:
+ quaternions: quaternions with real part first,
+ as tensor of shape (..., 4).
+
+ Returns:
+ Rotation matrices as tensor of shape (..., 3, 3).
+ """
+ r, i, j, k = torch.unbind(quaternions, -1) # -1 for the quaternion matrix #
+ # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
+
+ o = torch.stack(
+ (
+ 1 - two_s * (j * j + k * k),
+ two_s * (i * j - k * r),
+ two_s * (i * k + j * r),
+ two_s * (i * j + k * r),
+ 1 - two_s * (i * i + k * k),
+ two_s * (j * k - i * r),
+ two_s * (i * k - j * r),
+ two_s * (j * k + i * r),
+ 1 - two_s * (i * i + j * j),
+ ),
+ -1,
+ )
+
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
+
+
+
+
+class Inertial:
+ def __init__(self, origin_rpy, origin_xyz, mass, inertia) -> None:
+ self.origin_rpy = origin_rpy
+ self.origin_xyz = origin_xyz
+ self.mass = mass
+ self.inertia = inertia
+ if torch.sum(self.inertia).item() < 1e-4:
+ self.inertia = self.inertia + torch.eye(3, dtype=torch.float32).cuda()
+ pass
+
+class Visual:
+ def __init__(self, visual_xyz, visual_rpy, geometry_mesh_fn, geometry_mesh_scale) -> None:
+ # self.visual_origin = visual_origin
+ self.visual_xyz = visual_xyz
+ self.visual_rpy = visual_rpy
+ self.mesh_nm = geometry_mesh_fn.split("/")[-1].split(".")[0]
+ mesh_root = "/home/xueyi/diffsim/NeuS/rsc/mano"
+ if not os.path.exists(mesh_root):
+ mesh_root = "/data/xueyi/diffsim/NeuS/rsc/mano"
+ if "shadow" in urdf_fn and "left" in urdf_fn:
+ mesh_root = "/home/xueyi/diffsim/NeuS/rsc/shadow_hand_description_left"
+ if not os.path.exists(mesh_root):
+ mesh_root = "/root/diffsim/quasi-dyn/rsc/shadow_hand_description_left"
+ elif "shadow" in urdf_fn:
+ mesh_root = "/home/xueyi/diffsim/NeuS/rsc/shadow_hand_description"
+ if not os.path.exists(mesh_root):
+ mesh_root = "/root/diffsim/quasi-dyn/rsc/shadow_hand_description"
+ elif "redmax" in urdf_fn:
+ mesh_root = "/home/xueyi/diffsim/NeuS/rsc/redmax_hand"
+ if not os.path.exists(mesh_root):
+ mesh_root = "/root/diffsim/quasi-dyn/rsc/redmax_hand"
+
+ self.mesh_root = mesh_root
+ geometry_mesh_fn = geometry_mesh_fn.replace(".dae", ".obj")
+ self.geometry_mesh_fn = os.path.join(mesh_root, geometry_mesh_fn)
+
+ self.geometry_mesh_scale = geometry_mesh_scale
+ # tranformed by xyz #
+ self.vertices, self.faces = self.load_geoemtry_mesh()
+ self.cur_expanded_visual_pts = None
+ pass
+
+ def load_geoemtry_mesh(self, ):
+ # mesh_root =
+ # if self.geometry_mesh_fn.end
+ mesh = trimesh.load_mesh(self.geometry_mesh_fn)
+ vertices = mesh.vertices
+ faces = mesh.faces
+
+ vertices = torch.from_numpy(vertices).float().cuda()
+ faces =torch.from_numpy(faces).long().cuda()
+
+ vertices = vertices * self.geometry_mesh_scale.unsqueeze(0) + self.visual_xyz.unsqueeze(0)
+
+ return vertices, faces
+
+ # init_visual_meshes = get_init_visual_meshes(self, parent_rot, parent_trans, init_visual_meshes)
+ def get_init_visual_meshes(self, parent_rot, parent_trans, init_visual_meshes, expanded_pts=False):
+ # cur_vertices = torch.matmul(parent_rot, self.vertices.transpose(1, 0)).contiguous().transpose(1, 0).contiguous() + parent_trans.unsqueeze(0)
+
+ if not expanded_pts:
+ cur_vertices = self.vertices
+ # print(f"adding mesh loaded from {self.geometry_mesh_fn}")
+ init_visual_meshes['vertices'].append(cur_vertices) # cur vertices # trans #
+ init_visual_meshes['faces'].append(self.faces)
+ else:
+ ## expanded visual meshes ##
+ cur_vertices = self.cur_expanded_visual_pts
+ init_visual_meshes['vertices'].append(cur_vertices)
+ init_visual_meshes['faces'].append(self.faces)
+ return init_visual_meshes
+
+ def expand_visual_pts(self, ):
+ # expand_factor = 0.2
+ # nn_expand_pts = 20
+
+ # expand_factor = 0.4
+ # nn_expand_pts = 40 ### number of the expanded points ### ## points ##
+
+ # expand_factor = 0.2
+ # nn_expand_pts = 20 ##
+
+ expand_factor = 0.1
+ nn_expand_pts = 10 ##
+ expand_save_fn = f"{self.mesh_nm}_expanded_pts_factor_{expand_factor}_nnexp_{nn_expand_pts}_new.npy"
+ expand_save_fn = os.path.join(self.mesh_root, expand_save_fn) #
+
+ if not os.path.exists(expand_save_fn):
+ cur_expanded_visual_pts = []
+ if self.cur_expanded_visual_pts is None:
+ cur_src_pts = self.vertices
+ else:
+ cur_src_pts = self.cur_expanded_visual_pts
+ maxx_verts, _ = torch.max(cur_src_pts, dim=0)
+ minn_verts, _ = torch.min(cur_src_pts, dim=0)
+ extent_verts = maxx_verts - minn_verts ## (3,)-dim vecotr
+ norm_extent_verts = torch.norm(extent_verts, dim=-1).item() ## (1,)-dim vector
+ expand_r = norm_extent_verts * expand_factor
+ # nn_expand_pts = 5 # expand the vertices to 5 times of the original vertices
+ for i_pts in range(self.vertices.size(0)):
+ cur_pts = cur_src_pts[i_pts]
+ # sample from the circile with cur_pts as thejcenter and the radius as expand_r
+ # (-r, r) # sample the offset vector in the size of (nn_expand_pts, 3)
+ offset_dist = Uniform(-1. * expand_r, expand_r)
+ offset_vec = offset_dist.sample((nn_expand_pts, 3)).cuda()
+ cur_expanded_pts = cur_pts + offset_vec
+ cur_expanded_visual_pts.append(cur_expanded_pts)
+ cur_expanded_visual_pts = torch.cat(cur_expanded_visual_pts, dim=0)
+ np.save(expand_save_fn, cur_expanded_visual_pts.detach().cpu().numpy())
+ else:
+ print(f"Loading visual pts from {expand_save_fn}") # load from the fn #
+ cur_expanded_visual_pts = np.load(expand_save_fn, allow_pickle=True)
+ cur_expanded_visual_pts = torch.from_numpy(cur_expanded_visual_pts).float().cuda()
+ self.cur_expanded_visual_pts = cur_expanded_visual_pts # expanded visual pts #
+ return self.cur_expanded_visual_pts
+
+## epand
+## link urdf ## expand the visual pts to form the expanded visual grids pts #
+# use get_name_to_visual_pts_faces to get the transformed visual pts and faces #
+class Link_urdf:
+ def __init__(self, name, inertial: Inertial, visual: Visual=None) -> None:
+
+ self.name = name
+ self.inertial = inertial
+ self.visual = visual # vsiual meshes #
+
+ # self.joint = joint
+ # self.body = body
+ # self.children = children
+ # self.name = name
+
+ self.link_idx = ...
+
+ # self.args = args
+
+ self.joint = None # joint name to struct
+ # self.join
+ self.children = ...
+ self.children = {} # joint name to child sruct
+
+ def expand_visual_pts(self, expanded_visual_pts, link_name_to_visited, link_name_to_link_struct):
+ link_name_to_visited[self.name] = 1
+ if self.visual is not None:
+ cur_expanded_visual_pts = self.visual.expand_visual_pts()
+ expanded_visual_pts.append(cur_expanded_visual_pts)
+
+ for cur_link in self.children:
+ cur_link_struct = link_name_to_link_struct[self.children[cur_link]]
+ cur_link_name = cur_link_struct.name
+ if cur_link_name in link_name_to_visited:
+ continue
+ ## expanded visual pts for the expand visual ptsS ##
+ ## link name to visited ##
+ expanded_visual_pts = cur_link_struct.expand_visual_pts(expanded_visual_pts, link_name_to_visited, link_name_to_link_struct)
+ return expanded_visual_pts
+
+ def set_initial_state(self, states, action_joint_name_to_joint_idx, link_name_to_visited, link_name_to_link_struct):
+
+ link_name_to_visited[self.name] = 1
+
+ if self.joint is not None:
+ for cur_joint_name in self.joint:
+ cur_joint = self.joint[cur_joint_name]
+ cur_joint_name = cur_joint.name
+ cur_child = self.children[cur_joint_name]
+ cur_child_struct = link_name_to_link_struct[cur_child]
+ cur_child_name = cur_child_struct.name
+
+ if cur_child_name in link_name_to_visited:
+ continue
+ if cur_joint.type in ['revolute']:
+ cur_joint_idx = action_joint_name_to_joint_idx[cur_joint_name] # action joint name to joint idx #
+ # cur_joint_idx = action_joint_name_to_joint_idx[cur_joint_name] #
+ # cur_joint = self.joint[cur_joint_name]
+ cur_state = states[cur_joint_idx] ### joint state ###
+ cur_joint.set_initial_state(cur_state)
+ cur_child_struct.set_initial_state(states, action_joint_name_to_joint_idx, link_name_to_visited, link_name_to_link_struct)
+
+
+
+ def set_penetration_forces(self, action_joint_name_to_joint_idx, link_name_to_visited, link_name_to_link_struct, parent_rot, parent_trans, penetration_forces, sampled_visual_pts_joint_idxes, joint_penetration_forces):
+ link_name_to_visited[self.name] = 1
+
+ # the current joint of the # update state #
+ if self.joint is not None:
+ for cur_joint_name in self.joint:
+
+ cur_joint = self.joint[cur_joint_name] # joint model
+
+ cur_child = self.children[cur_joint_name] # child model #
+
+ cur_child_struct = link_name_to_link_struct[cur_child]
+
+ cur_child_name = cur_child_struct.name
+
+ cur_child_link_idx = cur_child_struct.link_idx
+
+ if cur_child_name in link_name_to_visited:
+ continue
+
+ try:
+ cur_child_inertia = cur_child_struct.cur_inertia
+ except:
+ cur_child_inertia = torch.eye(3, dtype=torch.float32).cuda()
+
+
+ if cur_joint.type in ['revolute'] and (cur_joint_name not in ['WRJ2', 'WRJ1']):
+ cur_joint_idx = action_joint_name_to_joint_idx[cur_joint_name]
+ # cur_action = actions[cur_joint_idx]
+ ### get the child struct ###
+ # set_actions_and_update_states(self, action, cur_timestep, time_cons, cur_inertia):
+ # set actions and update states #
+ cur_joint_rot, cur_joint_trans = cur_joint.compute_transformation_from_current_state(n_grad=True)
+ cur_joint_tot_rot = torch.matmul(parent_rot, cur_joint_rot) ## R_p (R_j p + t_j) + t_p
+ cur_joint_tot_trans = torch.matmul(parent_rot, cur_joint_trans.unsqueeze(-1)).squeeze(-1) + parent_trans
+
+ # cur_joint.set_actions_and_update_states_v2(cur_action, cur_timestep, time_cons, cur_child_inertia.detach(), parent_rot, parent_trans + cur_joint.origin_xyz, penetration_forces=penetration_forces, link_idx=cur_child_link_idx)
+
+ # cur_timestep, time_cons, cur_inertia, cur_joint_tot_rot=None, cur_joint_tot_trans=None, penetration_forces=None, sampled_visual_pts_joint_idxes=None, joint_idx=None
+
+
+ cur_joint.set_penetration_forces(cur_child_inertia.detach(), cur_joint_tot_rot, cur_joint_tot_trans, link_idx=cur_child_link_idx, penetration_forces=penetration_forces, sampled_visual_pts_joint_idxes=sampled_visual_pts_joint_idxes, joint_idx=cur_joint_idx - 2, joint_penetration_forces=joint_penetration_forces)
+ else:
+ cur_joint_tot_rot = parent_rot
+ cur_joint_tot_trans = parent_trans
+
+
+ cur_child_struct.set_penetration_forces(action_joint_name_to_joint_idx, link_name_to_visited, link_name_to_link_struct, parent_rot=cur_joint_tot_rot, parent_trans=cur_joint_tot_trans, penetration_forces=penetration_forces, sampled_visual_pts_joint_idxes=sampled_visual_pts_joint_idxes, joint_penetration_forces=joint_penetration_forces)
+
+
+
+
+
+ def get_init_visual_meshes(self, parent_rot, parent_trans, init_visual_meshes, link_name_to_link_struct, link_name_to_visited, expanded_pts=False, joint_idxes=None, state_vals=None):
+ link_name_to_visited[self.name] = 1
+
+ # 'transformed_joint_pos': [], 'link_idxes': []
+ if self.joint is not None: # get init visual meshes #
+ # for i_ch, (cur_joint, cur_child) in enumerate(zip(self.joint, self.children)):
+ # print(f"joint: {cur_joint.name}, child: {cur_child.name}, parent: {self.name}, child_visual: {cur_child.visual is not None}")
+ # joint_origin_xyz = cur_joint.origin_xyz
+ # init_visual_meshes = cur_child.get_init_visual_meshes(parent_rot, parent_trans + joint_origin_xyz, init_visual_meshes)
+ # print(f"name: {self.name}, keys: {self.joint.keys()}")
+ for cur_joint_name in self.joint: #
+ cur_joint = self.joint[cur_joint_name]
+
+ # if state_vals is not None:
+ # cur_joint_idx = cur_joint.joint_idx
+ # state_vals[cur_joint_idx] = cur_joint.state.detach().cpu().numpy()
+
+ cur_child_name = self.children[cur_joint_name]
+ cur_child = link_name_to_link_struct[cur_child_name]
+ # print(f"joint: {cur_joint.name}, child: {cur_child_name}, parent: {self.name}, child_visual: {cur_child.visual is not None}")
+ # print(f"joint: {cur_joint.name}, child: {cur_child_name}, parent: {self.name}, child_visual: {cur_child.visual is not None}")
+ joint_origin_xyz = cur_joint.origin_xyz
+ if cur_child_name in link_name_to_visited:
+ continue
+ cur_child_visual_pts = {'vertices': [], 'faces': [], 'link_idxes': [], 'transformed_joint_pos': [], 'joint_link_idxes': []}
+
+ # joint idxes #
+ cur_child_visual_pts, joint_idxes = cur_child.get_init_visual_meshes(parent_rot, parent_trans + joint_origin_xyz, cur_child_visual_pts, link_name_to_link_struct, link_name_to_visited, expanded_pts=expanded_pts, joint_idxes=joint_idxes)
+
+ cur_child_verts, cur_child_faces = cur_child_visual_pts['vertices'], cur_child_visual_pts['faces']
+ cur_child_link_idxes = cur_child_visual_pts['link_idxes']
+ cur_transformed_joint_pos = cur_child_visual_pts['transformed_joint_pos']
+ joint_link_idxes = cur_child_visual_pts['joint_link_idxes']
+
+ if len(cur_child_verts) > 0:
+ cur_child_verts, cur_child_faces = merge_meshes(cur_child_verts, cur_child_faces)
+ cur_child_verts = cur_child_verts + cur_joint.origin_xyz.unsqueeze(0)
+ cur_joint_rot, cur_joint_trans = cur_joint.compute_transformation_from_current_state()
+ cur_child_verts = torch.matmul(cur_joint_rot, cur_child_verts.transpose(1, 0).contiguous()).transpose(1, 0).contiguous() + cur_joint_trans.unsqueeze(0)
+
+ if len(cur_transformed_joint_pos) > 0:
+ cur_transformed_joint_pos = torch.cat(cur_transformed_joint_pos, dim=0)
+ cur_transformed_joint_pos = cur_transformed_joint_pos + cur_joint.origin_xyz.unsqueeze(0)
+ cur_transformed_joint_pos = torch.matmul(cur_joint_rot, cur_transformed_joint_pos.transpose(1, 0).contiguous()).transpose(1, 0).contiguous() + cur_joint_trans.unsqueeze(0)
+ cur_joint_pos = cur_joint_trans.unsqueeze(0).clone()
+ cur_transformed_joint_pos = torch.cat(
+ [cur_transformed_joint_pos, cur_joint_pos], dim=0 ##### joint poses #####
+ )
+ else:
+ cur_transformed_joint_pos = cur_joint_trans.unsqueeze(0).clone()
+
+ if len(joint_link_idxes) > 0:
+ joint_link_idxes = torch.cat(joint_link_idxes, dim=-1) ### joint_link idxes ###
+ cur_joint_idx = cur_child.link_idx
+ joint_link_idxes = torch.cat(
+ [joint_link_idxes, torch.tensor([cur_joint_idx], dtype=torch.long).cuda()], dim=-1
+ )
+ else:
+ joint_link_idxes = torch.tensor([cur_child.link_idx], dtype=torch.long).cuda().view(1,)
+
+ # joint link idxes #
+
+ # cur_child_verts = cur_child_verts + # transformed joint pos #
+ cur_child_link_idxes = torch.cat(cur_child_link_idxes, dim=-1)
+ # joint_link_idxes = torch.cat(joint_link_idxes, dim=-1)
+ init_visual_meshes['vertices'].append(cur_child_verts)
+ init_visual_meshes['faces'].append(cur_child_faces)
+ init_visual_meshes['link_idxes'].append(cur_child_link_idxes)
+ init_visual_meshes['transformed_joint_pos'].append(cur_transformed_joint_pos)
+ init_visual_meshes['joint_link_idxes'].append(joint_link_idxes)
+
+ # joint_origin_xyz = self.joint.origin_xyz # c ## get forces from the expanded point set ##
+ else:
+ joint_origin_xyz = torch.tensor([0., 0., 0.], dtype=torch.float32).cuda()
+ # self.parent_rot_mtx = parent_rot
+ # self.parent_trans_vec = parent_trans + joint_origin_xyz
+
+
+ if self.visual is not None:
+ # ## get init visual meshes ## ## --
+ init_visual_meshes = self.visual.get_init_visual_meshes(parent_rot, parent_trans, init_visual_meshes, expanded_pts=expanded_pts)
+ cur_visual_mesh_pts_nn = self.visual.vertices.size(0)
+ cur_link_idxes = torch.zeros((cur_visual_mesh_pts_nn, ), dtype=torch.long).cuda()+ self.link_idx
+ init_visual_meshes['link_idxes'].append(cur_link_idxes)
+
+ # self.link_idx #
+ if joint_idxes is not None:
+ cur_idxes = [self.link_idx for _ in range(cur_visual_mesh_pts_nn)]
+ cur_idxes = torch.tensor(cur_idxes, dtype=torch.long).cuda()
+ joint_idxes.append(cur_idxes)
+
+
+
+ # for cur_link in self.children: #
+ # init_visual_meshes = cur_link.get_init_visual_meshes(self.parent_rot_mtx, self.parent_trans_vec, init_visual_meshes)
+ return init_visual_meshes, joint_idxes ## init visual meshes ##
+
+ # calculate inerti
+ def calculate_inertia(self, link_name_to_visited, link_name_to_link_struct):
+ link_name_to_visited[self.name] = 1
+ self.cur_inertia = torch.zeros((3, 3), dtype=torch.float32).cuda()
+
+ if self.joint is not None:
+ for joint_nm in self.joint:
+ cur_joint = self.joint[joint_nm]
+ cur_child = self.children[joint_nm]
+ cur_child_struct = link_name_to_link_struct[cur_child]
+ cur_child_name = cur_child_struct.name
+ if cur_child_name in link_name_to_visited:
+ continue
+ joint_rot, joint_trans = cur_joint.compute_transformation_from_current_state(n_grad=True)
+ # cur_parent_rot = torch.matmul(parent_rot, joint_rot) #
+ # cur_parent_trans = torch.matmul(parent_rot, joint_trans.unsqueeze(-1)).squeeze(-1) + parent_trans #
+ child_inertia = cur_child_struct.calculate_inertia(link_name_to_visited, link_name_to_link_struct)
+ child_inertia = torch.matmul(
+ joint_rot.detach(), torch.matmul(child_inertia, joint_rot.detach().transpose(1, 0).contiguous())
+ ).detach()
+ self.cur_inertia += child_inertia
+ # if self.visual is not None:
+ # self.cur_inertia += self.visual.inertia
+ self.cur_inertia += self.inertial.inertia.detach()
+ return self.cur_inertia
+
+
+ def set_delta_state_and_update(self, states, cur_timestep, link_name_to_visited, action_joint_name_to_joint_idx, link_name_to_link_struct):
+
+ link_name_to_visited[self.name] = 1
+
+ if self.joint is not None:
+ for cur_joint_name in self.joint:
+
+ cur_joint = self.joint[cur_joint_name] # joint model
+
+ cur_child = self.children[cur_joint_name] # child model #
+
+ cur_child_struct = link_name_to_link_struct[cur_child]
+
+ cur_child_name = cur_child_struct.name
+
+ if cur_child_name in link_name_to_visited:
+ continue
+
+ ## cur child inertia ##
+ # cur_child_inertia = cur_child_struct.cur_inertia
+
+
+ if cur_joint.type in ['revolute']:
+ cur_joint_idx = action_joint_name_to_joint_idx[cur_joint_name]
+ cur_state = states[cur_joint_idx]
+ ### get the child struct ###
+ # set_actions_and_update_states(self, action, cur_timestep, time_cons, cur_inertia):
+ # set actions and update states #
+ cur_joint.set_delta_state_and_update(cur_state, cur_timestep)
+
+ cur_child_struct.set_delta_state_and_update(states, cur_timestep, link_name_to_visited, action_joint_name_to_joint_idx, link_name_to_link_struct)
+
+ def set_delta_state_and_update_v2(self, states, cur_timestep, link_name_to_visited, action_joint_name_to_joint_idx, link_name_to_link_struct):
+ link_name_to_visited[self.name] = 1
+
+ if self.joint is not None:
+ for cur_joint_name in self.joint:
+ cur_joint = self.joint[cur_joint_name] # joint model
+ cur_child = self.children[cur_joint_name] # child model #
+ cur_child_struct = link_name_to_link_struct[cur_child]
+ cur_child_name = cur_child_struct.name
+ if cur_child_name in link_name_to_visited:
+ continue
+ # cur_child_inertia = cur_child_struct.cur_inertia
+ if cur_joint.type in ['revolute']:
+ cur_joint_idx = action_joint_name_to_joint_idx[cur_joint_name]
+ cur_state = states[cur_joint_idx]
+ ### get the child struct ###
+ # set_actions_and_update_states(self, action, cur_timestep, time_cons, cur_inertia):
+ # set actions and update states #
+ cur_joint.set_delta_state_and_update_v2(cur_state, cur_timestep)
+ cur_child_struct.set_delta_state_and_update_v2(states, cur_timestep, link_name_to_visited, action_joint_name_to_joint_idx, link_name_to_link_struct)
+
+ # get_joint_state(self, cur_ts, state_vals):
+ def get_joint_state(self, cur_ts, state_vals, link_name_to_visited, link_name_to_link_struct, action_joint_name_to_joint_idx):
+ link_name_to_visited[self.name] = 1
+ if self.joint is not None:
+ for cur_joint_name in self.joint:
+
+ cur_joint = self.joint[cur_joint_name] # joint model
+
+ cur_child = self.children[cur_joint_name] # child model #
+
+ cur_child_struct = link_name_to_link_struct[cur_child]
+
+ cur_child_name = cur_child_struct.name
+
+ if cur_child_name in link_name_to_visited:
+ continue
+
+ if cur_joint.type in ['revolute']:
+ cur_joint_idx = action_joint_name_to_joint_idx[cur_joint_name]
+ state_vals[cur_joint_idx] = cur_joint.timestep_to_states[cur_ts + 1] # .state.detach().cpu().numpy()
+ # state_vals = cur_joint.get_joint_state(cur_ts, state_vals)
+
+ state_vals = cur_child_struct.get_joint_state(cur_ts, state_vals, link_name_to_visited, link_name_to_link_struct, action_joint_name_to_joint_idx)
+ return state_vals
+
+ # the joint #
+ # set_actions_and_update_states(actions, cur_timestep, time_cons, action_joint_name_to_joint_idx, link_name_to_visited, self.link_name_to_link_struct)
+ def set_actions_and_update_states(self, actions, cur_timestep, time_cons, action_joint_name_to_joint_idx, link_name_to_visited, link_name_to_link_struct):
+
+ link_name_to_visited[self.name] = 1
+
+ # the current joint of the # update state #
+ if self.joint is not None:
+ for cur_joint_name in self.joint:
+
+ cur_joint = self.joint[cur_joint_name] # joint model
+
+ cur_child = self.children[cur_joint_name] # child model #
+
+ cur_child_struct = link_name_to_link_struct[cur_child]
+
+ cur_child_name = cur_child_struct.name
+
+ if cur_child_name in link_name_to_visited:
+ continue
+
+ cur_child_inertia = cur_child_struct.cur_inertia
+
+
+ if cur_joint.type in ['revolute']:
+ cur_joint_idx = action_joint_name_to_joint_idx[cur_joint_name]
+ cur_action = actions[cur_joint_idx]
+ ### get the child struct ###
+ # set_actions_and_update_states(self, action, cur_timestep, time_cons, cur_inertia):
+ # set actions and update states #
+ cur_joint.set_actions_and_update_states(cur_action, cur_timestep, time_cons, cur_child_inertia.detach())
+
+ cur_child_struct.set_actions_and_update_states(actions, cur_timestep, time_cons, action_joint_name_to_joint_idx, link_name_to_visited, link_name_to_link_struct)
+
+
+ def set_actions_and_update_states_v2(self, actions, cur_timestep, time_cons, action_joint_name_to_joint_idx, link_name_to_visited, link_name_to_link_struct, parent_rot, parent_trans, penetration_forces=None, sampled_visual_pts_joint_idxes=None):
+
+ link_name_to_visited[self.name] = 1
+
+ # the current joint of the # update state #
+ if self.joint is not None:
+ for cur_joint_name in self.joint:
+
+ cur_joint = self.joint[cur_joint_name] # joint model
+
+ cur_child = self.children[cur_joint_name] # child model #
+
+ cur_child_struct = link_name_to_link_struct[cur_child]
+
+ cur_child_name = cur_child_struct.name
+
+ cur_child_link_idx = cur_child_struct.link_idx
+
+ if cur_child_name in link_name_to_visited:
+ continue
+
+ try:
+ cur_child_inertia = cur_child_struct.cur_inertia
+ except:
+ cur_child_inertia = torch.eye(3, dtype=torch.float32).cuda()
+
+
+ if cur_joint.type in ['revolute']:
+ cur_joint_idx = action_joint_name_to_joint_idx[cur_joint_name]
+ cur_action = actions[cur_joint_idx]
+ ### get the child struct ###
+ # set_actions_and_update_states(self, action, cur_timestep, time_cons, cur_inertia):
+ # set actions and update states #
+ cur_joint_rot, cur_joint_trans = cur_joint.compute_transformation_from_current_state(n_grad=True)
+ cur_joint_tot_rot = torch.matmul(parent_rot, cur_joint_rot) ## R_p (R_j p + t_j) + t_p
+ cur_joint_tot_trans = torch.matmul(parent_rot, cur_joint_trans.unsqueeze(-1)).squeeze(-1) + parent_trans
+
+ # cur_joint.set_actions_and_update_states_v2(cur_action, cur_timestep, time_cons, cur_child_inertia.detach(), parent_rot, parent_trans + cur_joint.origin_xyz, penetration_forces=penetration_forces, link_idx=cur_child_link_idx)
+
+ cur_joint.set_actions_and_update_states_v2(cur_action, cur_timestep, time_cons, cur_child_inertia.detach(), cur_joint_tot_rot, cur_joint_tot_trans, penetration_forces=penetration_forces, link_idx=cur_child_link_idx, sampled_visual_pts_joint_idxes=sampled_visual_pts_joint_idxes)
+ else:
+ cur_joint_tot_rot = parent_rot
+ cur_joint_tot_trans = parent_trans
+
+
+ cur_child_struct.set_actions_and_update_states_v2(actions, cur_timestep, time_cons, action_joint_name_to_joint_idx, link_name_to_visited, link_name_to_link_struct, parent_rot=cur_joint_tot_rot, parent_trans=cur_joint_tot_trans, penetration_forces=penetration_forces, sampled_visual_pts_joint_idxes=sampled_visual_pts_joint_idxes)
+
+
+ def set_init_states_target_value(self, init_states):
+ if self.joint.type == 'revolute':
+ self.joint_angle = init_states[self.joint.joint_idx]
+ joint_axis = self.joint.axis
+ self.rot_vec = self.joint_angle * joint_axis
+ self.joint.state = torch.tensor([1, 0, 0, 0], dtype=torch.float32).cuda()
+ self.joint.state = self.joint.state + update_quaternion(self.rot_vec, self.joint.state)
+ self.joint.timestep_to_states[0] = self.joint.state.detach()
+ self.joint.timestep_to_vels[0] = torch.zeros((3,), dtype=torch.float32).cuda().detach() ## velocity ##
+ for cur_link in self.children:
+ cur_link.set_init_states_target_value(init_states)
+
+ # should forward for one single step -> use the action #
+ def set_init_states(self, ):
+ self.joint.state = torch.tensor([1, 0, 0, 0], dtype=torch.float32).cuda()
+ self.joint.timestep_to_states[0] = self.joint.state.detach()
+ self.joint.timestep_to_vels[0] = torch.zeros((3,), dtype=torch.float32).cuda().detach() ## velocity ##
+ for cur_link in self.children:
+ cur_link.set_init_states()
+
+
+ def get_visual_pts(self, visual_pts_list):
+ visual_pts_list = self.body.get_visual_pts(visual_pts_list)
+ for cur_link in self.children:
+ visual_pts_list = cur_link.get_visual_pts(visual_pts_list)
+ visual_pts_list = torch.cat(visual_pts_list, dim=0)
+ return visual_pts_list
+
+ def get_visual_faces_list(self, visual_faces_list):
+ visual_faces_list = self.body.get_visual_faces_list(visual_faces_list)
+ for cur_link in self.children:
+ visual_faces_list = cur_link.get_visual_faces_list(visual_faces_list)
+ return visual_faces_list
+ # pass
+
+
+ def set_state(self, name_to_state):
+ self.joint.set_state(name_to_state=name_to_state)
+ for child_link in self.children:
+ child_link.set_state(name_to_state)
+
+ def set_state_via_vec(self, state_vec):
+ self.joint.set_state_via_vec(state_vec)
+ for child_link in self.children:
+ child_link.set_state_via_vec(state_vec)
+
+
+
+
+class Joint_Limit:
+ def __init__(self, effort, lower, upper, velocity) -> None:
+ self.effort = effort
+ self.lower = lower
+ self.velocity = velocity
+ self.upper = upper
+ pass
+
+# Joint_urdf(name, joint_type, parent_link, child_link, origin_xyz, axis_xyz, limit: Joint_Limit)
+class Joint_urdf: #
+
+ def __init__(self, name, joint_type, parent_link, child_link, origin_xyz, axis_xyz, limit: Joint_Limit, origin_xyz_string="") -> None:
+ self.name = name
+ self.type = joint_type
+ self.parent_link = parent_link
+ self.child_link = child_link
+ self.origin_xyz = origin_xyz
+ self.axis_xyz = axis_xyz
+ self.limit = limit
+
+ self.origin_xyz_string = origin_xyz_string
+
+ # joint angle; joint state #
+ self.timestep_to_vels = {}
+ self.timestep_to_states = {}
+
+ self.init_pos = self.origin_xyz.clone()
+
+ #### only for the current state #### # joint urdf #
+ self.state = nn.Parameter(
+ torch.tensor([1., 0., 0., 0.], dtype=torch.float32, requires_grad=True).cuda(), requires_grad=True
+ )
+ self.action = nn.Parameter(
+ torch.zeros((1,), dtype=torch.float32, requires_grad=True).cuda(), requires_grad=True
+ )
+ # self.rot_mtx = np.eye(3, dtypes=np.float32)
+ # self.trans_vec = np.zeros((3,), dtype=np.float32) ## rot m
+ self.rot_mtx = nn.Parameter(torch.eye(n=3, dtype=torch.float32, requires_grad=True).cuda(), requires_grad=True)
+ self.trans_vec = nn.Parameter(torch.zeros((3,), dtype=torch.float32, requires_grad=True).cuda(), requires_grad=True)
+
+ def set_initial_state(self, state):
+ # joint angle as the state value #
+ self.timestep_to_vels[0] = torch.zeros((3,), dtype=torch.float32).cuda().detach() ## velocity ##
+ delta_rot_vec = self.axis_xyz * state
+ # self.timestep_to_states[0] = state.detach()
+ cur_state = torch.tensor([1., 0., 0., 0.], dtype=torch.float32).cuda()
+ init_state = cur_state + update_quaternion(delta_rot_vec, cur_state)
+ self.timestep_to_states[0] = init_state.detach()
+ self.state = init_state
+
+ def set_delta_state_and_update(self, state, cur_timestep):
+ self.timestep_to_vels[cur_timestep] = torch.zeros((3,), dtype=torch.float32).cuda().detach()
+ delta_rot_vec = self.axis_xyz * state
+ if cur_timestep == 0:
+ prev_state = torch.tensor([1., 0., 0., 0.], dtype=torch.float32).cuda()
+ else:
+ # prev_state = self.timestep_to_states[cur_timestep - 1].detach()
+ prev_state = self.timestep_to_states[cur_timestep - 1] # .detach() # not detach? #
+ cur_state = prev_state + update_quaternion(delta_rot_vec, prev_state)
+ self.timestep_to_states[cur_timestep] = cur_state.detach()
+ self.state = cur_state
+
+
+ def set_delta_state_and_update_v2(self, delta_state, cur_timestep):
+ self.timestep_to_vels[cur_timestep] = torch.zeros((3,), dtype=torch.float32).cuda().detach()
+
+ if cur_timestep == 0:
+ cur_state = delta_state
+ else:
+ # prev_state = self.timestep_to_states[cur_timestep - 1].detach()
+ # prev_state = self.timestep_to_states[cur_timestep - 1]
+ cur_state = self.timestep_to_states[cur_timestep - 1].detach() + delta_state
+ ## cur_state ## #
+ self.timestep_to_states[cur_timestep] = cur_state # .detach()
+
+
+ # delta_rot_vec = self.axis_xyz * state #
+
+ cur_rot_vec = self.axis_xyz * cur_state ### cur_state #### #
+ # angle to the quaternion ? #
+ init_state = torch.tensor([1., 0., 0., 0.], dtype=torch.float32).cuda()
+ cur_quat_state = init_state + update_quaternion(cur_rot_vec, init_state)
+ self.state = cur_quat_state
+
+ # if cur_timestep == 0:
+ # prev_state = torch.tensor([1., 0., 0., 0.], dtype=torch.float32).cuda()
+ # else:
+ # # prev_state = self.timestep_to_states[cur_timestep - 1].detach()
+ # prev_state = self.timestep_to_states[cur_timestep - 1] # .detach() # not detach? #
+ # cur_state = prev_state + update_quaternion(delta_rot_vec, prev_state)
+ # self.timestep_to_states[cur_timestep] = cur_state.detach()
+ # self.state = cur_state
+
+
+ def compute_transformation_from_current_state(self, n_grad=False):
+ # together with the parent rot mtx and the parent trans vec #
+ # cur_joint_state = self.state
+ if self.type == "revolute":
+ # rot_mtx = rotation_matrix_from_axis_angle(self.axis, cur_joint_state)
+ # trans_vec = self.pos - np.matmul(rot_mtx, self.pos.reshape(3, 1)).reshape(3)
+ if n_grad:
+ rot_mtx = quaternion_to_matrix(self.state.detach())
+ else:
+ rot_mtx = quaternion_to_matrix(self.state)
+ # trans_vec = self.pos - torch.matmul(rot_mtx, self.pos.view(3, 1)).view(3).contiguous()
+ trans_vec = self.origin_xyz - torch.matmul(rot_mtx, self.origin_xyz.view(3, 1)).view(3).contiguous()
+ self.rot_mtx = rot_mtx
+ self.trans_vec = trans_vec
+ elif self.type == "fixed":
+ rot_mtx = torch.eye(3, dtype=torch.float32).cuda()
+ trans_vec = torch.zeros((3,), dtype=torch.float32).cuda()
+ # trans_vec = self.origin_xyz
+ self.rot_mtx = rot_mtx
+ self.trans_vec = trans_vec #
+ else:
+ pass
+ return self.rot_mtx, self.trans_vec
+
+
+ # set actions # set actions and udpate states #
+ def set_actions_and_update_states(self, action, cur_timestep, time_cons, cur_inertia):
+
+ # timestep_to_vels, timestep_to_states, state #
+ if self.type in ['revolute']:
+
+ self.action = action
+ #
+ # visual_pts and visual_pts_mass #
+ # cur_joint_pos = self.joint.pos #
+ # TODO: check whether the following is correct # # set
+ torque = self.action * self.axis_xyz
+
+ # # Compute inertia matrix #
+ # inertial = torch.zeros((3, 3), dtype=torch.float32).cuda()
+ # for i_pts in range(self.visual_pts.size(0)):
+ # cur_pts = self.visual_pts[i_pts]
+ # cur_pts_mass = self.visual_pts_mass[i_pts]
+ # cur_r = cur_pts - cur_joint_pos # r_i
+ # # cur_vert = init_passive_mesh[i_v]
+ # # cur_r = cur_vert - init_passive_mesh_center
+ # dot_r_r = torch.sum(cur_r * cur_r)
+ # cur_eye_mtx = torch.eye(3, dtype=torch.float32).cuda()
+ # r_mult_rT = torch.matmul(cur_r.unsqueeze(-1), cur_r.unsqueeze(0))
+ # inertial += (dot_r_r * cur_eye_mtx - r_mult_rT) * cur_pts_mass
+ # m = torch.sum(self.visual_pts_mass)
+ # # Use torque to update angular velocity -> state #
+ # inertia_inv = torch.linalg.inv(inertial)
+
+ # axis-angle of
+ # inertia_inv = self.cur_inertia_inv
+ # print(f"updating actions and states for the joint {self.name} with type {self.type}")
+ inertia_inv = torch.linalg.inv(cur_inertia).detach()
+
+ delta_omega = torch.matmul(inertia_inv, torque.unsqueeze(-1)).squeeze(-1)
+
+ # delta_omega = torque / 400 # # axis_xyz #
+
+ ## actions -> with the dynamic information -> time cons -> angular acc -> delta angular vel -> delta angle
+ # TODO: dt should be an optim#izable constant? should it be the same value as that optimized for the passive object? #
+ delta_angular_vel = delta_omega * time_cons # * self.args.dt
+ delta_angular_vel = delta_angular_vel.squeeze(0)
+ if cur_timestep > 0: ## cur_timestep - 1 ##
+ prev_angular_vel = self.timestep_to_vels[cur_timestep - 1].detach()
+ # cur_angular_vel = prev_angular_vel + delta_angular_vel * DAMPING
+ cur_angular_vel = prev_angular_vel * DAMPING + delta_angular_vel # p
+ else:
+ cur_angular_vel = delta_angular_vel # angular vel #
+
+ self.timestep_to_vels[cur_timestep] = cur_angular_vel.detach()
+
+ cur_delta_quat = cur_angular_vel * time_cons # * self.args.dt
+ cur_delta_quat = cur_delta_quat.squeeze(0) # delta quat #
+ cur_state = self.timestep_to_states[cur_timestep].detach() # quaternion #
+ # print(f"cur_delta_quat: {cur_delta_quat.size()}, cur_state: {cur_state.size()}")
+ nex_state = cur_state + update_quaternion(cur_delta_quat, cur_state)
+ self.timestep_to_states[cur_timestep + 1] = nex_state.detach()
+ self.state = nex_state # set the joint state #
+
+
+ def set_actions_and_update_states_v2(self, action, cur_timestep, time_cons, cur_inertia, cur_joint_tot_rot=None, cur_joint_tot_trans=None, penetration_forces=None, link_idx=None, sampled_visual_pts_joint_idxes=None):
+
+ # timestep_to_vels, timestep_to_states, state #
+ if self.type in ['revolute']:
+
+ self.action = action ## strategy 2
+ #
+ # visual_pts and visual_pts_mass #
+ # cur_joint_pos = self.joint.pos #
+ # TODO: check whether the following is correct # # set
+
+ if penetration_forces is not None:
+ penetration_forces_values = penetration_forces['penetration_forces'].detach()
+ penetration_forces_points = penetration_forces['penetration_forces_points'].detach()
+
+ ####### use a part of peentration points and forces #######
+ if sampled_visual_pts_joint_idxes is not None:
+ selected_forces_mask = sampled_visual_pts_joint_idxes == link_idx ## select the current link's penetrated points
+ else:
+ selected_forces_mask = torch.ones_like(penetration_forces_values[:, 0]).bool()
+ ####### use a part of peentration points and forces #######
+
+ if torch.sum(selected_forces_mask.float()) > 0.5: ## has penetrated points in this link ##
+
+ penetration_forces_values = penetration_forces_values[selected_forces_mask]
+ penetration_forces_points = penetration_forces_points[selected_forces_mask]
+ # tot_rot_mtx, tot_trans_vec
+ # cur_joint_rot = self.tot_rot_mtx
+ # cur_joint_trans = self.tot_trans_vec
+ cur_joint_rot = cur_joint_tot_rot.detach()
+ cur_joint_trans = cur_joint_tot_trans.detach() ## total rot; total trans ##
+ local_frame_penetration_forces_values = torch.matmul(cur_joint_rot.transpose(1, 0), penetration_forces_values.transpose(1, 0)).transpose(1, 0)
+ local_frame_penetration_forces_points = torch.matmul(cur_joint_rot.transpose(1, 0), (penetration_forces_points - cur_joint_trans.unsqueeze(0)).transpose(1, 0)).transpose(1, 0)
+
+ joint_pos_to_forces_points = local_frame_penetration_forces_points - self.axis_xyz.unsqueeze(0)
+ forces_torques = torch.cross(joint_pos_to_forces_points, local_frame_penetration_forces_values) # forces values of the local frame #
+ forces_torques = torch.sum(forces_torques, dim=0)
+
+ forces_torques_dot_axis = torch.sum(self.axis_xyz * forces_torques)
+ penetration_delta_state = forces_torques_dot_axis
+ else:
+ penetration_delta_state = 0.0
+ else:
+ penetration_delta_state = 0.0
+
+
+ torque = self.action * self.axis_xyz
+
+ # # Compute inertia matrix #
+ # inertial = torch.zeros((3, 3), dtype=torch.float32).cuda()
+ # for i_pts in range(self.visual_pts.size(0)):
+ # cur_pts = self.visual_pts[i_pts]
+ # cur_pts_mass = self.visual_pts_mass[i_pts]
+ # cur_r = cur_pts - cur_joint_pos # r_i
+ # # cur_vert = init_passive_mesh[i_v]
+ # # cur_r = cur_vert - init_passive_mesh_center
+ # dot_r_r = torch.sum(cur_r * cur_r)
+ # cur_eye_mtx = torch.eye(3, dtype=torch.float32).cuda()
+ # r_mult_rT = torch.matmul(cur_r.unsqueeze(-1), cur_r.unsqueeze(0))
+ # inertial += (dot_r_r * cur_eye_mtx - r_mult_rT) * cur_pts_mass
+ # m = torch.sum(self.visual_pts_mass)
+ # # Use torque to update angular velocity -> state #
+ # inertia_inv = torch.linalg.inv(inertial)
+
+ # axis-angle of
+ # inertia_inv = self.cur_inertia_inv
+ # print(f"updating actions and states for the joint {self.name} with type {self.type}")
+
+
+ # inertia_inv = torch.linalg.inv(cur_inertia).detach()
+
+ inertia_inv = torch.eye(n=3, dtype=torch.float32).cuda()
+
+
+
+ delta_omega = torch.matmul(inertia_inv, torque.unsqueeze(-1)).squeeze(-1)
+
+ # delta_omega = torque / 400
+
+
+ # TODO: dt should be an optim#izable constant? should it be the same value as that optimized for the passive object? #
+ delta_angular_vel = delta_omega * time_cons # * self.args.dt
+ delta_angular_vel = delta_angular_vel.squeeze(0)
+ if cur_timestep > 0: ## cur_timestep - 1 ##
+ prev_angular_vel = self.timestep_to_vels[cur_timestep - 1].detach()
+ # cur_angular_vel = prev_angular_vel + delta_angular_vel * DAMPING
+ cur_angular_vel = prev_angular_vel * DAMPING + delta_angular_vel # p
+ # cur_angular_vel = prev_angular_vel + delta_angular_vel # p
+ else:
+ cur_angular_vel = delta_angular_vel # angular vel #
+
+ self.timestep_to_vels[cur_timestep] = cur_angular_vel.detach()
+
+ cur_delta_angle = cur_angular_vel * time_cons # * self.args.dt
+ # cur_delta_quat = cur_delta_angle.squeeze(0) # delta quat #
+ # cur_state = self.timestep_to_states[cur_timestep].detach() # quaternion #
+ # # print(f"cur_delta_quat: {cur_delta_quat.size()}, cur_state: {cur_state.size()}")
+ # nex_state = cur_state + update_quaternion(cur_delta_quat, cur_state)
+
+ ### strategy 2 ###
+ dot_cur_delta_angle_w_axis = torch.sum( ## delta angle with axises ##
+ cur_delta_angle * self.axis_xyz, dim=-1
+ )
+ ## dot cur deltawith the
+ delta_state = dot_cur_delta_angle_w_axis ## delta angle w axieses ##
+
+ # if cur_timestep
+ if cur_timestep == 0:
+ self.timestep_to_states[cur_timestep] = torch.zeros((1,), dtype=torch.float32).cuda()
+ cur_state = self.timestep_to_states[cur_timestep].detach()
+ nex_state = cur_state + delta_state
+ # nex_state = nex_state + penetration_delta_state
+ ## state rot vector along axis ## ## get the pentrated froces -- calulaterot qj
+ state_rot_vec_along_axis = nex_state * self.axis_xyz
+ ### state in the rotation vector -> state in quaternion ###
+ state_rot_quat = torch.tensor([1., 0., 0., 0.], dtype=torch.float32).cuda() + update_quaternion(state_rot_vec_along_axis, torch.tensor([1., 0., 0., 0.], dtype=torch.float32).cuda())
+ ### state
+ self.state = state_rot_quat
+ ### get states? ##
+ self.timestep_to_states[cur_timestep + 1] = nex_state # .detach()
+ # self.state = nex_state # set the joint state #
+
+
+
+ def set_penetration_forces(self, cur_inertia, cur_joint_tot_rot=None, cur_joint_tot_trans=None, link_idx=None, penetration_forces=None, sampled_visual_pts_joint_idxes=None, joint_idx=None, joint_penetration_forces=None):
+
+ # timestep_to_vels, timestep_to_states, state #
+ if self.type in ['revolute'] :
+
+ # self.action = action ## strategy 2
+ #
+ # visual_pts and visual_pts_mass #
+ # cur_joint_pos = self.joint.pos #
+ # TODO: check whether the following is correct # # set
+
+ if penetration_forces is not None:
+ penetration_forces_values = penetration_forces['penetration_forces'].detach()
+ penetration_forces_points = penetration_forces['penetration_forces_points'].detach()
+
+ ####### use a part of peentration points and forces #######
+ if sampled_visual_pts_joint_idxes is not None:
+ selected_forces_mask = sampled_visual_pts_joint_idxes == link_idx ## select the current link's penetrated points
+ else:
+ selected_forces_mask = torch.ones_like(penetration_forces_values[:, 0]).bool()
+ ####### use a part of peentration points and forces #######
+
+ if torch.sum(selected_forces_mask.float()) > 0.5: ## has penetrated points in this link ##
+
+ penetration_forces_values = penetration_forces_values[selected_forces_mask]
+ penetration_forces_points = penetration_forces_points[selected_forces_mask]
+ # tot_rot_mtx, tot_trans_vec
+ # cur_joint_rot = self.tot_rot_mtx
+ # cur_joint_trans = self.tot_trans_vec
+ cur_joint_rot = cur_joint_tot_rot.detach()
+ cur_joint_trans = cur_joint_tot_trans.detach() ## total rot; total trans ##
+ local_frame_penetration_forces_values = torch.matmul(cur_joint_rot.transpose(1, 0), penetration_forces_values.transpose(1, 0)).transpose(1, 0)
+ local_frame_penetration_forces_points = torch.matmul(cur_joint_rot.transpose(1, 0), (penetration_forces_points - cur_joint_trans.unsqueeze(0)).transpose(1, 0)).transpose(1, 0)
+
+ joint_pos_to_forces_points = local_frame_penetration_forces_points - self.axis_xyz.unsqueeze(0)
+ forces_torques = torch.cross(joint_pos_to_forces_points, local_frame_penetration_forces_values) # forces values of the local frame #
+ forces_torques = torch.sum(forces_torques, dim=0)
+
+ forces = torch.sum(local_frame_penetration_forces_values, dim=0)
+
+ cur_joint_maximal_forces = torch.cat(
+ [forces, forces_torques], dim=0
+ )
+ cur_joint_idx = joint_idx
+ joint_penetration_forces[cur_joint_idx][:] = cur_joint_maximal_forces[:].clone()
+
+ # forces_torques_dot_axis = torch.sum(self.axis_xyz * forces_torques)
+ # penetration_delta_state = forces_torques_dot_axis
+ else:
+ penetration_delta_state = 0.0
+ cur_joint_maximal_forces = torch.zeros((6,), dtype=torch.float32).cuda()
+ cur_joint_idx = joint_idx
+ joint_penetration_forces[cur_joint_idx][:] = cur_joint_maximal_forces[:].clone()
+
+ else:
+ penetration_delta_state = 0.0
+ cur_joint_idx = joint_idx
+ joint_penetration_forces[cur_joint_idx][:] = cur_joint_maximal_forces[:].clone()
+
+
+
+
+
+
+ def get_joint_state(self, cur_ts, state_vals):
+ cur_joint_state = self.timestep_to_states[cur_ts + 1]
+ state_vals[self.joint_idx] = cur_joint_state
+ return state_vals
+
+
+class Robot_urdf:
+ def __init__(self, links, link_name_to_link_idxes, link_name_to_link_struct, joint_name_to_joint_idx, actions_joint_name_to_joint_idx, tot_joints=None, real_actions_joint_name_to_joint_idx=None) -> None:
+ self.links = links
+ self.link_name_to_link_idxes = link_name_to_link_idxes
+ self.link_name_to_link_struct = link_name_to_link_struct
+
+ # joint_name_to_joint_idx, actions_joint_name_to_joint_idx
+ self.joint_name_to_joint_idx = joint_name_to_joint_idx
+ self.actions_joint_name_to_joint_idx = actions_joint_name_to_joint_idx
+
+ self.tot_joints = tot_joints
+ # #
+ # #
+ self.act_joint_idxes = list(self.actions_joint_name_to_joint_idx.values())
+ self.act_joint_idxes = sorted(self.act_joint_idxes, reverse=False)
+ self.act_joint_idxes = torch.tensor(self.act_joint_idxes, dtype=torch.long).cuda()[2:]
+
+ self.real_actions_joint_name_to_joint_idx = real_actions_joint_name_to_joint_idx
+
+
+ self.init_vertices, self.init_faces = self.get_init_visual_pts()
+
+ joint_name_to_joint_idx_sv_fn = "mano_joint_name_to_joint_idx.npy"
+ np.save(joint_name_to_joint_idx_sv_fn, self.joint_name_to_joint_idx)
+
+ actions_joint_name_to_joint_idx_sv_fn = "mano_actions_joint_name_to_joint_idx.npy"
+ np.save(actions_joint_name_to_joint_idx_sv_fn, self.actions_joint_name_to_joint_idx)
+
+ tot_joints = len(self.joint_name_to_joint_idx)
+ tot_actions_joints = len(self.actions_joint_name_to_joint_idx)
+
+ print(f"tot_joints: {tot_joints}, tot_actions_joints: {tot_actions_joints}")
+
+ pass
+
+ # robot.expande
+ def expand_visual_pts(self, ):
+ link_name_to_visited = {}
+ # transform the visual pts #
+ # action_joint_name_to_joint_idx = self.actions_joint_name_to_joint_idx
+
+ palm_idx = self.link_name_to_link_idxes["palm"]
+ palm_link = self.links[palm_idx]
+ expanded_visual_pts = []
+ # expanded the visual pts # # transformed viusal pts # or the translations of the visual pts #
+ expanded_visual_pts = palm_link.expand_visual_pts(expanded_visual_pts, link_name_to_visited, self.link_name_to_link_struct)
+ expanded_visual_pts = torch.cat(expanded_visual_pts, dim=0)
+ # pass
+ return expanded_visual_pts
+
+
+ ### samping issue? --- TODO` `
+ def get_init_visual_pts(self, expanded_pts=False, joint_idxes=None):
+ init_visual_meshes = {
+ 'vertices': [], 'faces': [], 'link_idxes': [], 'transformed_joint_pos': [], 'link_idxes': [], 'transformed_joint_pos': [], 'joint_link_idxes': []
+ }
+ init_parent_rot = torch.eye(3, dtype=torch.float32).cuda()
+ init_parent_trans = torch.zeros((3,), dtype=torch.float32).cuda()
+
+ palm_idx = self.link_name_to_link_idxes["palm"]
+ palm_link = self.links[palm_idx]
+
+ link_name_to_visited = {}
+
+ ### from the palm linke ##
+ init_visual_meshes, joint_idxes = palm_link.get_init_visual_meshes(init_parent_rot, init_parent_trans, init_visual_meshes, self.link_name_to_link_struct, link_name_to_visited, expanded_pts=expanded_pts, joint_idxes=joint_idxes)
+
+ self.link_idxes = torch.cat(init_visual_meshes['link_idxes'], dim=-1)
+ self.transformed_joint_pos = torch.cat(init_visual_meshes['transformed_joint_pos'], dim=0)
+ self.joint_link_idxes = torch.cat(init_visual_meshes['joint_link_idxes'], dim=-1) ###
+
+
+ if joint_idxes is not None:
+ joint_idxes = torch.cat(joint_idxes, dim=0)
+
+ # for cur_link in self.links:
+ # init_visual_meshes = cur_link.get_init_visual_meshes(init_parent_rot, init_parent_trans, init_visual_meshes, self.link_name_to_link_struct, link_name_to_visited)
+
+ init_vertices, init_faces = merge_meshes(init_visual_meshes['vertices'], init_visual_meshes['faces'])
+
+ if joint_idxes is not None:
+ return init_vertices, init_faces, joint_idxes
+ else:
+ return init_vertices, init_faces
+
+ def set_penetration_forces(self, penetration_forces, sampled_visual_pts_joint_idxes, joint_penetration_forces):
+ palm_idx = self.link_name_to_link_idxes["palm"]
+ palm_link = self.links[palm_idx]
+
+ link_name_to_visited = {}
+
+ action_joint_name_to_joint_idx = self.real_actions_joint_name_to_joint_idx
+ # print(f"action_joint_name_to_joint_idx: {action_joint_name_to_joint_idx}")
+
+ parent_rot = torch.eye(3, dtype=torch.float32).cuda()
+ parent_trans = torch.zeros((3,), dtype=torch.float32).cuda()
+
+ # cur_timestep, time_cons, action_joint_name_to_joint_idx, link_name_to_visited, link_name_to_link_struct, parent_rot, parent_trans, penetration_forces, sampled_visual_pts_joint_idxes, joint_penetration_forces):
+
+ palm_link.set_penetration_forces(action_joint_name_to_joint_idx, link_name_to_visited, self.link_name_to_link_struct, parent_rot, parent_trans, penetration_forces, sampled_visual_pts_joint_idxes, joint_penetration_forces)
+
+ def set_delta_state_and_update(self, states, cur_timestep):
+ link_name_to_visited = {}
+
+ action_joint_name_to_joint_idx = self.actions_joint_name_to_joint_idx
+
+ palm_idx = self.link_name_to_link_idxes["palm"]
+ palm_link = self.links[palm_idx]
+
+ link_name_to_visited = {}
+
+ palm_link.set_delta_state_and_update(states, cur_timestep, link_name_to_visited, action_joint_name_to_joint_idx, self.link_name_to_link_struct)
+
+
+ def set_delta_state_and_update_v2(self, states, cur_timestep, use_real_act_joint=False):
+ link_name_to_visited = {}
+
+ if use_real_act_joint:
+ action_joint_name_to_joint_idx = self.real_actions_joint_name_to_joint_idx
+ else:
+ action_joint_name_to_joint_idx = self.actions_joint_name_to_joint_idx
+
+ palm_idx = self.link_name_to_link_idxes["palm"]
+ palm_link = self.links[palm_idx]
+
+ link_name_to_visited = {}
+
+ palm_link.set_delta_state_and_update_v2(states, cur_timestep, link_name_to_visited, action_joint_name_to_joint_idx, self.link_name_to_link_struct)
+
+
+
+ # cur_joint.set_actions_and_update_states(cur_action, cur_timestep, time_cons, cur_child_inertia)
+ def set_actions_and_update_states(self, actions, cur_timestep, time_cons,):
+ # self.actions_joint_name_to_joint_idx as the action joint name to joint idx
+ link_name_to_visited = {}
+ ## to joint idx ##
+ action_joint_name_to_joint_idx = self.actions_joint_name_to_joint_idx
+
+ palm_idx = self.link_name_to_link_idxes["palm"]
+ palm_link = self.links[palm_idx]
+
+ link_name_to_visited = {}
+
+ ## set actions ##
+ palm_link.set_actions_and_update_states(actions, cur_timestep, time_cons, action_joint_name_to_joint_idx, link_name_to_visited, self.link_name_to_link_struct)
+
+ # for cur_joint in
+
+ # for cur_link in self.links:
+ # if cur_link.joint is not None:
+ # for cur_joint_nm in cur_link.joint:
+ # if cur_link.joint[cur_joint_nm].type in ['revolute']:
+ # cur_link_joint_name = cur_link.joint[cur_joint_nm].name
+ # cur_link_joint_idx = self.actions_joint_name_to_joint_idx[cur_link_joint_name]
+
+
+ # for cur_link in self.links:
+ # cur_link.set_actions_and_update_states(actions, cur_timestep, time_cons, action_joint_name_to_joint_idx, link_name_to_visited, self.link_name_to_link_struct)
+
+
+
+ def get_joint_state(self, cur_ts, state_vals):
+ # link_name_to_visited = {}
+ ## to joint idx ##
+ # action_joint_name_to_joint_idx = self.actions_joint_name_to_joint_idx
+
+ palm_idx = self.link_name_to_link_idxes["palm"]
+ palm_link = self.links[palm_idx]
+
+ link_name_to_visited = {}
+
+ # parent_rot = torch.eye(3, dtype=torch.float32).cuda()
+ # parent_trans = torch.zeros((3,), dtype=torch.float32).cuda()
+ ## set actions ## #
+ # set_actions_and_update_states_v2(self, action, cur_timestep, time_cons, cur_inertia):
+ # self, actions, cur_timestep, time_cons, action_joint_name_to_joint_idx, link_name_to_visited, link_name_to_link_struct ## set and update states ##
+ state_vals = palm_link.get_joint_state(cur_ts, state_vals, link_name_to_visited, self.link_name_to_link_struct, self.actions_joint_name_to_joint_idx)
+ return state_vals
+
+
+ def set_actions_and_update_states_v2(self, actions, cur_timestep, time_cons, penetration_forces=None, sampled_visual_pts_joint_idxes=None):
+ # self.actions_joint_name_to_joint_idx as the action joint name to joint idx
+ link_name_to_visited = {}
+ ## to joint idx ##
+ action_joint_name_to_joint_idx = self.actions_joint_name_to_joint_idx
+
+ palm_idx = self.link_name_to_link_idxes["palm"]
+ palm_link = self.links[palm_idx]
+
+ link_name_to_visited = {}
+
+ parent_rot = torch.eye(3, dtype=torch.float32).cuda()
+ parent_trans = torch.zeros((3,), dtype=torch.float32).cuda()
+ ## set actions ## #
+ # set_actions_and_update_states_v2(self, action, cur_timestep, time_cons, cur_inertia):
+ # self, actions, cur_timestep, time_cons, action_joint_name_to_joint_idx, link_name_to_visited, link_name_to_link_struct ## set and update states ##
+ palm_link.set_actions_and_update_states_v2(actions, cur_timestep, time_cons, action_joint_name_to_joint_idx, link_name_to_visited, self.link_name_to_link_struct, parent_rot=parent_rot, parent_trans=parent_trans, penetration_forces=penetration_forces, sampled_visual_pts_joint_idxes=sampled_visual_pts_joint_idxes)
+
+
+ ### TODO: add the contact torque when calculating the nextstep states ###
+ ### TODO: not an accurate implementation since differen joints should be considered ###
+ ### TODO: the articulated force modle is not so easy as this one .... ###
+ def set_contact_forces(self, hard_selected_forces, hard_selected_manipulating_points, hard_selected_sampled_input_pts_idxes):
+ # transformed_joint_pos, joint_link_idxes, link_idxes #
+ selected_pts_link_idxes = self.link_idxes[hard_selected_sampled_input_pts_idxes]
+ # use the selected link idxes #
+ # selected pts idxes #
+
+ # self.joint_link_idxes, transformed_joint_pos #
+ self.link_idx_to_transformed_joint_pos = {}
+ for i_link in range(self.transformed_joint_pos.size(0)):
+ cur_link_idx = self.link_idxes[i_link].item()
+ cur_link_pos = self.transformed_joint_pos[i_link]
+ # if cur_link_idx not in self.link_idx_to_transformed_joint_pos:
+ self.link_idx_to_transformed_joint_pos[cur_link_idx] = cur_link_pos
+ # self.link_idx_to_transformed_joint_pos[cur_link_idx].append(cur_link_pos)
+
+ # from the
+ self.link_idx_to_contact_forces = {}
+ for i_c_pts in range(hard_selected_forces.size(0)):
+ cur_contact_force = hard_selected_forces[i_c_pts] ##
+ cur_link_idx = selected_pts_link_idxes[i_c_pts].item()
+ cur_link_pos = self.link_idx_to_transformed_joint_pos[cur_link_idx]
+ cur_link_action_pos = hard_selected_manipulating_points[i_c_pts]
+ # (action_pos - link_pos) x (-contact_force) #
+ cur_contact_torque = torch.cross(
+ cur_link_action_pos - cur_link_pos, -cur_contact_force
+ )
+ if cur_link_idx not in self.link_idx_to_contact_forces:
+ self.link_idx_to_contact_forces[cur_link_idx] = [cur_contact_torque]
+ else:
+ self.link_idx_to_contact_forces[cur_link_idx].append(cur_contact_torque)
+ for link_idx in self.link_idx_to_contact_forces:
+ self.link_idx_to_contact_forces[link_idx] = torch.stack(self.link_idx_to_contact_forces[link_idx], dim=0)
+ self.link_idx_to_contact_forces[link_idx] = torch.sum(self.link_idx_to_contact_forces[link_idx] , dim=0)
+ for link_idx, link_struct in enumerate(self.links):
+ if link_idx in self.link_idx_to_contact_forces:
+ cur_link_contact_force = self.link_idx_to_contact_forces[link_idx]
+ link_struct.contact_torque = cur_link_contact_force
+ else:
+ link_struct.contact_torque = None
+
+
+ # def se ### from the optimizable initial states ###
+ def set_initial_state(self, states):
+ action_joint_name_to_joint_idx = self.actions_joint_name_to_joint_idx
+ link_name_to_visited = {}
+
+ palm_idx = self.link_name_to_link_idxes["palm"]
+ palm_link = self.links[palm_idx]
+
+ link_name_to_visited = {}
+
+ palm_link.set_initial_state(states, action_joint_name_to_joint_idx, link_name_to_visited, self.link_name_to_link_struct)
+
+ # for cur_link in self.links:
+ # cur_link.set_initial_state(states, action_joint_name_to_joint_idx, link_name_to_visited, self.link_name_to_link_struct)
+
+ ### after each timestep -> re-calculate the inertial matrix using the current simulated states and the set the new actiosn and forward the simulation #
+ def calculate_inertia(self):
+ link_name_to_visited = {}
+
+ palm_idx = self.link_name_to_link_idxes["palm"]
+ palm_link = self.links[palm_idx]
+
+ link_name_to_visited = {}
+
+ palm_link.calculate_inertia(link_name_to_visited, self.link_name_to_link_struct)
+
+ # for cur_link in self.links:
+ # cur_link.calculate_inertia(link_name_to_visited, self.link_name_to_link_struct)
+
+ ###
+
+
+
+
+def parse_nparray_from_string(strr, args=None):
+ vals = strr.split(" ")
+ vals = [float(val) for val in vals]
+ vals = np.array(vals, dtype=np.float32)
+ vals = torch.from_numpy(vals).float()
+ ## vals ##
+ vals = nn.Parameter(vals.cuda(), requires_grad=True)
+
+ return vals
+
+
+### parse link data ###
+def parse_link_data(link, args):
+
+ link_name = link.attrib["name"]
+ # print(f"parsing link: {link_name}") ## joints body meshes #
+
+ joint = link.find("./joint")
+
+ joint_name = joint.attrib["name"]
+ joint_type = joint.attrib["type"]
+ if joint_type in ["revolute"]: ## a general xml parser here?
+ axis = joint.attrib["axis"]
+ axis = parse_nparray_from_string(axis, args=args)
+ else:
+ axis = None
+ pos = joint.attrib["pos"] #
+ pos = parse_nparray_from_string(pos, args=args)
+ quat = joint.attrib["quat"]
+ quat = parse_nparray_from_string(quat, args=args)
+
+ try:
+ frame = joint.attrib["frame"]
+ except:
+ frame = "WORLD"
+
+ if joint_type not in ["fixed"]:
+ damping = joint.attrib["damping"]
+ damping = float(damping)
+ else:
+ damping = 0.0
+
+ cur_joint = Joint(joint_name, joint_type, axis, pos, quat, frame, damping, args=args)
+
+ body = link.find("./body")
+ body_name = body.attrib["name"]
+ body_type = body.attrib["type"]
+ if body_type == "mesh":
+ filename = body.attrib["filename"]
+ else:
+ filename = ""
+
+ if body_type == "sphere":
+ radius = body.attrib["radius"]
+ radius = float(radius)
+ else:
+ radius = 0.
+
+ pos = body.attrib["pos"]
+ pos = parse_nparray_from_string(pos, args=args)
+ quat = body.attrib["quat"]
+ quat = joint.attrib["quat"]
+ try:
+ transform_type = body.attrib["transform_type"]
+ except:
+ transform_type = "OBJ_TO_WORLD"
+ density = body.attrib["density"]
+ density = float(density)
+ mu = body.attrib["mu"]
+ mu = float(mu)
+ try: ## rgba ##
+ rgba = body.attrib["rgba"]
+ rgba = parse_nparray_from_string(rgba, args=args)
+ except:
+ rgba = np.zeros((4,), dtype=np.float32)
+
+ cur_body = Body(body_name, body_type, filename, pos, quat, transform_type, density, mu, rgba, radius, args=args)
+
+ children_link = []
+ links = link.findall("./link")
+ for child_link in links: #
+ cur_child_link = parse_link_data(child_link, args=args)
+ children_link.append(cur_child_link)
+
+ link_name = link.attrib["name"]
+ link_obj = Link(link_name, joint=cur_joint, body=cur_body, children=children_link, args=args)
+ return link_obj
+
+
+### parse link data ###
+def parse_link_data_urdf(link):
+
+ link_name = link.attrib["name"]
+ # print(f"parsing link: {link_name}") ## joints body meshes #
+
+ inertial = link.find("./inertial")
+
+ origin = inertial.find("./origin")
+
+ if origin is not None:
+ inertial_pos = origin.attrib["xyz"]
+ try:
+ inertial_rpy = origin.attrib["rpy"]
+ except:
+ inertial_rpy = "0.0 0.0 0.0"
+ else:
+ inertial_pos = "0.0 0.0 0.0"
+ inertial_rpy = "0.0 0.0 0.0"
+ inertial_pos = parse_nparray_from_string(inertial_pos)
+
+ inertial_rpy = parse_nparray_from_string(inertial_rpy)
+
+ inertial_mass = inertial.find("./mass")
+ inertial_mass = inertial_mass.attrib["value"]
+
+ inertial_inertia = inertial.find("./inertia")
+ inertial_ixx = inertial_inertia.attrib["ixx"]
+ inertial_ixx = float(inertial_ixx)
+ inertial_ixy = inertial_inertia.attrib["ixy"]
+ inertial_ixy = float(inertial_ixy)
+ inertial_ixz = inertial_inertia.attrib["ixz"]
+ inertial_ixz = float(inertial_ixz)
+ inertial_iyy = inertial_inertia.attrib["iyy"]
+ inertial_iyy = float(inertial_iyy)
+ inertial_iyz = inertial_inertia.attrib["iyz"]
+ inertial_iyz = float(inertial_iyz)
+ inertial_izz = inertial_inertia.attrib["izz"]
+ inertial_izz = float(inertial_izz)
+
+ inertial_inertia_mtx = torch.zeros((3, 3), dtype=torch.float32).cuda()
+ inertial_inertia_mtx[0, 0] = inertial_ixx
+ inertial_inertia_mtx[0, 1] = inertial_ixy
+ inertial_inertia_mtx[0, 2] = inertial_ixz
+ inertial_inertia_mtx[1, 0] = inertial_ixy
+ inertial_inertia_mtx[1, 1] = inertial_iyy
+ inertial_inertia_mtx[1, 2] = inertial_iyz
+ inertial_inertia_mtx[2, 0] = inertial_ixz
+ inertial_inertia_mtx[2, 1] = inertial_iyz
+ inertial_inertia_mtx[2, 2] = inertial_izz
+
+ # [xx, xy, xz] #
+ # [0, yy, yz] #
+ # [0, 0, zz] #
+
+ # a strange inertia value ... #
+ # TODO: how to compute the inertia matrix? #
+
+ visual = link.find("./visual")
+
+ if visual is not None:
+ origin = visual.find("./origin")
+ visual_pos = origin.attrib["xyz"]
+ visual_pos = parse_nparray_from_string(visual_pos)
+ visual_rpy = origin.attrib["rpy"]
+ visual_rpy = parse_nparray_from_string(visual_rpy)
+ geometry = visual.find("./geometry")
+ geometry_mesh = geometry.find("./mesh")
+ if geometry_mesh is None:
+ visual = None
+ else:
+ mesh_fn = geometry_mesh.attrib["filename"]
+
+ try:
+ mesh_scale = geometry_mesh.attrib["scale"]
+ except:
+ mesh_scale = "1 1 1"
+
+ mesh_scale = parse_nparray_from_string(mesh_scale)
+ mesh_fn = str(mesh_fn)
+
+
+ link_struct = Link_urdf(name=link_name, inertial=Inertial(origin_rpy=inertial_rpy, origin_xyz=inertial_pos, mass=inertial_mass, inertia=inertial_inertia_mtx), visual=Visual(visual_rpy=visual_rpy, visual_xyz=visual_pos, geometry_mesh_fn=mesh_fn, geometry_mesh_scale=mesh_scale) if visual is not None else None)
+
+ return link_struct
+
+def parse_joint_data_urdf(joint):
+ joint_name = joint.attrib["name"]
+ joint_type = joint.attrib["type"]
+
+ parent = joint.find("./parent")
+ child = joint.find("./child")
+ parent_name = parent.attrib["link"]
+ child_name = child.attrib["link"]
+
+ joint_origin = joint.find("./origin")
+ # if joint_origin.
+ try:
+ origin_xyz_string = joint_origin.attrib["xyz"]
+ origin_xyz = parse_nparray_from_string(origin_xyz_string)
+ except:
+ origin_xyz = torch.tensor([0., 0., 0.], dtype=torch.float32).cuda()
+ origin_xyz_string = ""
+
+ joint_axis = joint.find("./axis")
+ if joint_axis is not None:
+ joint_axis = joint_axis.attrib["xyz"]
+ joint_axis = parse_nparray_from_string(joint_axis)
+ else:
+ joint_axis = torch.tensor([1, 0., 0.], dtype=torch.float32).cuda()
+
+ joint_limit = joint.find("./limit")
+ if joint_limit is not None:
+ joint_lower = joint_limit.attrib["lower"]
+ joint_lower = float(joint_lower)
+ joint_upper = joint_limit.attrib["upper"]
+ joint_upper = float(joint_upper)
+ joint_effort = joint_limit.attrib["effort"]
+ joint_effort = float(joint_effort)
+ if "velocity" in joint_limit.attrib:
+ joint_velocity = joint_limit.attrib["velocity"]
+ joint_velocity = float(joint_velocity)
+ else:
+ joint_velocity = 0.5
+ else:
+ joint_lower = -0.5000
+ joint_upper = 1.57
+ joint_effort = 1000
+ joint_velocity = 0.5
+
+ # cosntruct the joint data #
+ joint_limit = Joint_Limit(effort=joint_effort, lower=joint_lower, upper=joint_upper, velocity=joint_velocity)
+ cur_joint_struct = Joint_urdf(joint_name, joint_type, parent_name, child_name, origin_xyz, joint_axis, joint_limit, origin_xyz_string)
+ return cur_joint_struct
+
+
+
+def parse_data_from_urdf(xml_fn):
+
+ tree = ElementTree()
+ tree.parse(xml_fn)
+ print(f"{xml_fn}")
+ ### get total robots ###
+ # robots = tree.findall("link")
+ cur_robot = tree
+ # i_robot = 0
+ # tot_robots = []
+ # for cur_robot in robots:
+ # print(f"Getting robot: {i_robot}")
+ # i_robot += 1
+ # print(f"len(robots): {len(robots)}")
+ # cur_robot = robots[0]
+ cur_links = cur_robot.findall("./link")
+ # curlinks
+ # i_link = 0
+ link_name_to_link_idxes = {}
+ cur_robot_links = []
+ link_name_to_link_struct = {}
+ for i_link_idx, cur_link in enumerate(cur_links):
+ cur_link_struct = parse_link_data_urdf(cur_link)
+ print(f"Adding link {cur_link_struct.name}, link_idx: {i_link_idx}")
+ cur_link_struct.link_idx = i_link_idx
+ cur_robot_links.append(cur_link_struct)
+
+ link_name_to_link_idxes[cur_link_struct.name] = i_link_idx
+ link_name_to_link_struct[cur_link_struct.name] = cur_link_struct
+ # for cur_link in cur_links:
+ # cur_robot_links.append(parse_link_data_urdf(cur_link, args=args))
+
+ print(f"link_name_to_link_struct: {len(link_name_to_link_struct)}, ")
+
+ tot_robot_joints = []
+
+ joint_name_to_joint_idx = {}
+
+ actions_joint_name_to_joint_idx = {}
+
+ cur_joints = cur_robot.findall("./joint")
+
+ real_actions_joint_name_to_joint_idx = {}
+
+ act_joint_idx = 0
+ for i_joint, cur_joint in enumerate(cur_joints):
+ cur_joint_struct = parse_joint_data_urdf(cur_joint)
+ cur_joint_parent_link = cur_joint_struct.parent_link
+ cur_joint_child_link = cur_joint_struct.child_link
+
+ cur_joint_idx = len(tot_robot_joints)
+ cur_joint_name = cur_joint_struct.name
+
+ joint_name_to_joint_idx[cur_joint_name] = cur_joint_idx
+
+ print(f"cur_joint_name: {cur_joint_name}, cur_joint_idx: {cur_joint_idx}, axis: {cur_joint_struct.axis_xyz}, origin: {cur_joint_struct.origin_xyz}")
+
+ cur_joint_type = cur_joint_struct.type
+ if cur_joint_type in ['revolute']:
+ actions_joint_name_to_joint_idx[cur_joint_name] = cur_joint_idx
+ # actions_joint_name_to_joint_idx[cur_joint_name] = act_joint_idx
+ # act_joint_idx = act_joint_idx + 1
+
+ real_actions_joint_name_to_joint_idx[cur_joint_name] = act_joint_idx
+ act_joint_idx = act_joint_idx + 1
+
+
+ #### add the current joint to tot joints ###
+ tot_robot_joints.append(cur_joint_struct)
+
+ parent_link_idx = link_name_to_link_idxes[cur_joint_parent_link]
+ cur_parent_link_struct = cur_robot_links[parent_link_idx]
+
+
+ child_link_idx = link_name_to_link_idxes[cur_joint_child_link]
+ cur_child_link_struct = cur_robot_links[child_link_idx]
+ # parent link struct #
+ if link_name_to_link_struct[cur_joint_parent_link].joint is not None:
+ link_name_to_link_struct[cur_joint_parent_link].joint[cur_joint_struct.name] = cur_joint_struct
+ link_name_to_link_struct[cur_joint_parent_link].children[cur_joint_struct.name] = cur_child_link_struct.name
+ # cur_child_link_struct
+ # cur_parent_link_struct.joint.append(cur_joint_struct)
+ # cur_parent_link_struct.children.append(cur_child_link_struct)
+ else:
+ link_name_to_link_struct[cur_joint_parent_link].joint = {
+ cur_joint_struct.name: cur_joint_struct
+ }
+ link_name_to_link_struct[cur_joint_parent_link].children = {
+ cur_joint_struct.name: cur_child_link_struct.name
+ # cur_child_link_struct
+ }
+ # cur_parent_link_struct.joint = [cur_joint_struct]
+ # cur_parent_link_struct.children.append(cur_child_link_struct)
+ # pass
+
+ print(f"actions_joint_name_to_joint_idx: {len(actions_joint_name_to_joint_idx)}")
+ print(f"real_actions_joint_name_to_joint_idx: {len(real_actions_joint_name_to_joint_idx)}")
+ cur_robot_obj = Robot_urdf(cur_robot_links, link_name_to_link_idxes, link_name_to_link_struct, joint_name_to_joint_idx, actions_joint_name_to_joint_idx, tot_robot_joints, real_actions_joint_name_to_joint_idx=real_actions_joint_name_to_joint_idx)
+ # tot_robots.append(cur_robot_obj)
+
+ print(f"Actions joint idxes:")
+ print(list(actions_joint_name_to_joint_idx.keys()))
+
+ actions_joint_idxes = list(actions_joint_name_to_joint_idx.values())
+ actions_joint_idxes = sorted(actions_joint_idxes)
+ print(f"joint indexes: {actions_joint_idxes}")
+
+ # for the joint robots #
+ # for every joint
+ # tot_actuators = []
+ # actuators = tree.findall("./actuator/motor")
+ # joint_nm_to_joint_idx = {}
+ # i_act = 0
+ # for cur_act in actuators:
+ # cur_act_joint_nm = cur_act.attrib["joint"]
+ # joint_nm_to_joint_idx[cur_act_joint_nm] = i_act
+ # i_act += 1 ### add the act ###
+
+ # tot_robots[0].set_joint_idx(joint_nm_to_joint_idx) ### set joint idx here ### # tot robots #
+ # tot_robots[0].get_nn_pts()
+ # tot_robots[1].get_nn_pts()
+
+ return cur_robot_obj
+
+
+def get_name_to_state_from_str(states_str):
+ tot_states = states_str.split(" ")
+ tot_states = [float(cur_state) for cur_state in tot_states]
+ joint_name_to_state = {}
+ for i in range(len(tot_states)):
+ cur_joint_name = f"joint{i + 1}"
+ cur_joint_state = tot_states[i]
+ joint_name_to_state[cur_joint_name] = cur_joint_state
+ return joint_name_to_state
+
+
+def merge_meshes(verts_list, faces_list):
+ nn_verts = 0
+ tot_verts_list = []
+ tot_faces_list = []
+ for i_vv, cur_verts in enumerate(verts_list):
+ cur_verts_nn = cur_verts.size(0)
+ tot_verts_list.append(cur_verts)
+ tot_faces_list.append(faces_list[i_vv] + nn_verts)
+ nn_verts = nn_verts + cur_verts_nn
+ tot_verts_list = torch.cat(tot_verts_list, dim=0)
+ tot_faces_list = torch.cat(tot_faces_list, dim=0)
+ return tot_verts_list, tot_faces_list
+
+
+### get init s
+class RobotAgent: # robot and the robot #
+ def __init__(self, xml_fn, args=None) -> None:
+ global urdf_fn
+ urdf_fn = xml_fn
+ self.xml_fn = xml_fn
+ # self.args = args
+
+ ##
+ active_robot = parse_data_from_urdf(xml_fn)
+
+ self.time_constant = nn.Embedding(
+ num_embeddings=3, embedding_dim=1
+ ).cuda()
+ torch.nn.init.ones_(self.time_constant.weight) #
+ self.time_constant.weight.data = self.time_constant.weight.data * 0.2 ### time_constant data #
+
+ self.optimizable_actions = nn.Embedding(
+ num_embeddings=100, embedding_dim=60,
+ ).cuda()
+ torch.nn.init.zeros_(self.optimizable_actions.weight) #
+
+ self.learning_rate = 5e-4
+
+ self.active_robot = active_robot
+
+
+ self.set_init_states()
+ init_visual_pts = self.get_init_state_visual_pts()
+ self.init_visual_pts = init_visual_pts
+
+ cur_verts, cur_faces = self.active_robot.get_init_visual_pts()
+ self.robot_pts = cur_verts
+ self.robot_faces = cur_faces
+
+
+ def set_init_states_target_value(self, init_states):
+ # glb_rot = torch.eye(n=3, dtype=torch.float32).cuda()
+ # glb_trans = torch.zeros((3,), dtype=torch.float32).cuda() ### glb_trans #### and the rot 3##
+
+ # tot_init_states = {}
+ # tot_init_states['glb_rot'] = glb_rot;
+ # tot_init_states['glb_trans'] = glb_trans;
+ # tot_init_states['links_init_states'] = init_states
+ # self.active_robot.set_init_states_target_value(tot_init_states)
+ # init_joint_states = torch.zeros((60, ), dtype=torch.float32).cuda()
+ self.active_robot.set_initial_state(init_states)
+
+ def set_init_states(self):
+ # glb_rot = torch.eye(n=3, dtype=torch.float32).cuda()
+ # glb_trans = torch.zeros((3,), dtype=torch.float32).cuda() ### glb_trans #### and the rot 3##
+
+ # ### random rotation ###
+ # # glb_rot_np = R.random().as_matrix()
+ # # glb_rot = torch.from_numpy(glb_rot_np).float().cuda()
+ # ### random rotation ###
+
+ # # glb_rot, glb_trans #
+ # init_states = {}
+ # init_states['glb_rot'] = glb_rot;
+ # init_states['glb_trans'] = glb_trans;
+ # self.active_robot.set_init_states(init_states)
+
+ init_joint_states = torch.zeros((60, ), dtype=torch.float32).cuda()
+ self.active_robot.set_initial_state(init_joint_states)
+
+ # cur_verts, joint_idxes = get_init_state_visual_pts(expanded_pts=False, ret_joint_idxes=True)
+ def get_init_state_visual_pts(self, expanded_pts=False, ret_joint_idxes=False):
+ # visual_pts_list = [] # compute the transformation via current state #
+ # visual_pts_list, visual_pts_mass_list = self.active_robot.compute_transformation_via_current_state( visual_pts_list)
+
+ if ret_joint_idxes:
+ joint_idxes = []
+ cur_verts, cur_faces, joint_idxes = self.active_robot.get_init_visual_pts(expanded_pts=expanded_pts, joint_idxes=joint_idxes)
+ else:
+ cur_verts, cur_faces = self.active_robot.get_init_visual_pts(expanded_pts=expanded_pts, joint_idxes=None)
+ self.faces = cur_faces
+ # joint_idxes = torch.cat()
+ # self.robot_pts = cur_verts
+ # self.robot_faces = cur_faces
+ # init_visual_pts = visual_pts_list
+ if ret_joint_idxes:
+ return cur_verts, joint_idxes
+ else:
+ return cur_verts
+
+ def set_actions_and_update_states(self, actions, cur_timestep):
+ #
+ time_cons = self.time_constant(torch.zeros((1,), dtype=torch.long).cuda()) ### time constant of the system ##
+ self.active_robot.set_actions_and_update_states(actions, cur_timestep, time_cons) ###
+ pass
+
+ def set_actions_and_update_states_v2(self, actions, cur_timestep, penetration_forces=None, sampled_visual_pts_joint_idxes=None):
+ #
+ time_cons = self.time_constant(torch.zeros((1,), dtype=torch.long).cuda()) ### time constant of the system ##
+ self.active_robot.set_actions_and_update_states_v2(actions, cur_timestep, time_cons, penetration_forces=penetration_forces, sampled_visual_pts_joint_idxes=sampled_visual_pts_joint_idxes) ###
+ pass
+
+ # state_vals = self.robot_agent.get_joint_state( cur_ts, state_vals, link_name_to_link_struct)
+ def get_joint_state(self, cur_ts, state_vals):
+ state_vals = self.active_robot.get_joint_state(cur_ts, state_vals)
+ return state_vals
+
+ def forward_stepping_test(self, ):
+ # delta_glb_rot; delta_glb_trans #
+ timestep_to_visual_pts = {}
+ for i_step in range(50):
+ actions = {}
+ actions['delta_glb_rot'] = torch.eye(3, dtype=torch.float32).cuda()
+ actions['delta_glb_trans'] = torch.zeros((3,), dtype=torch.float32).cuda()
+ actions_link_actions = torch.ones((22, ), dtype=torch.float32).cuda()
+ # actions_link_actions = actions_link_actions * 0.2
+ actions_link_actions = actions_link_actions * -1. #
+ actions['link_actions'] = actions_link_actions
+ self.set_actions_and_update_states(actions=actions, cur_timestep=i_step)
+
+ cur_visual_pts = robot_agent.get_init_state_visual_pts()
+ cur_visual_pts = cur_visual_pts.detach().cpu().numpy()
+ timestep_to_visual_pts[i_step + 1] = cur_visual_pts
+ return timestep_to_visual_pts
+
+ def initialize_optimization(self, reference_pts_dict):
+ self.n_timesteps = 50
+ # self.n_timesteps = 19 # first 19-timesteps optimization #
+ self.nn_tot_optimization_iters = 100
+ # self.nn_tot_optimization_iters = 57
+ # TODO: load reference points #
+ self.ts_to_reference_pts = np.load(reference_pts_dict, allow_pickle=True).item() ####
+ self.ts_to_reference_pts = {
+ ts // 2 + 1: torch.from_numpy(self.ts_to_reference_pts[ts]).float().cuda() for ts in self.ts_to_reference_pts
+ }
+
+
+ def forward_stepping_optimization(self, ):
+ nn_tot_optimization_iters = self.nn_tot_optimization_iters
+ params_to_train = []
+ params_to_train += list(self.optimizable_actions.parameters())
+ self.optimizer = torch.optim.Adam(params_to_train, lr=self.learning_rate)
+
+ for i_iter in range(nn_tot_optimization_iters):
+
+ tot_losses = []
+ ts_to_robot_points = {}
+ for cur_ts in range(self.n_timesteps):
+ # print(f"iter: {i_iter}, cur_ts: {cur_ts}")
+ # actions = {}
+ # actions['delta_glb_rot'] = torch.eye(3, dtype=torch.float32).cuda()
+ # actions['delta_glb_trans'] = torch.zeros((3,), dtype=torch.float32).cuda()
+ actions_link_actions = self.optimizable_actions(torch.zeros((1,), dtype=torch.long).cuda() + cur_ts).squeeze(0)
+ # actions_link_actions = actions_link_actions * 0.2
+ # actions_link_actions = actions_link_actions * -1. #
+ # actions['link_actions'] = actions_link_actions
+ # self.set_actions_and_update_states(actions=actions, cur_timestep=cur_ts) # update the interaction #
+
+ with torch.no_grad():
+ self.active_robot.calculate_inertia()
+
+ self.active_robot.set_actions_and_update_states(actions_link_actions, cur_ts, 0.2)
+
+ cur_visual_pts, cur_faces = self.active_robot.get_init_visual_pts()
+ ts_to_robot_points[cur_ts + 1] = cur_visual_pts.clone()
+
+ cur_reference_pts = self.ts_to_reference_pts[cur_ts + 1]
+ diff = torch.sum((cur_visual_pts - cur_reference_pts) ** 2, dim=-1)
+ diff = diff.mean()
+
+ # diff.
+ self.optimizer.zero_grad()
+ diff.backward(retain_graph=True)
+ # diff.backward(retain_graph=False)
+ self.optimizer.step()
+
+ tot_losses.append(diff.item())
+
+
+ loss = sum(tot_losses) / float(len(tot_losses))
+ print(f"Iter: {i_iter}, average loss: {loss}")
+ # print(f"Iter: {i_iter}, average loss: {loss.item()}, start optimizing")
+ # self.optimizer.zero_grad()
+ # loss.backward()
+ # self.optimizer.step()
+
+ self.ts_to_robot_points = {
+ ts: ts_to_robot_points[ts].detach().cpu().numpy() for ts in ts_to_robot_points
+ }
+ self.ts_to_ref_points = {
+ ts: self.ts_to_reference_pts[ts].detach().cpu().numpy() for ts in ts_to_robot_points
+ }
+ return self.ts_to_robot_points, self.ts_to_ref_points
+
+
+
+
+def rotation_matrix_from_axis_angle(axis, angle): # rotation_matrix_from_axis_angle ->
+ # sin_ = np.sin(angle) # ti.math.sin(angle)
+ # cos_ = np.cos(angle) # ti.math.cos(angle)
+ sin_ = torch.sin(angle) # ti.math.sin(angle)
+ cos_ = torch.cos(angle) # ti.math.cos(angle)
+ u_x, u_y, u_z = axis[0], axis[1], axis[2]
+ u_xx = u_x * u_x
+ u_yy = u_y * u_y
+ u_zz = u_z * u_z
+ u_xy = u_x * u_y
+ u_xz = u_x * u_z
+ u_yz = u_y * u_z ##
+
+
+ row_a = torch.stack(
+ [cos_ + u_xx * (1 - cos_), u_xy * (1. - cos_) + u_z * sin_, u_xz * (1. - cos_) - u_y * sin_], dim=0
+ )
+ # print(f"row_a: {row_a.size()}")
+ row_b = torch.stack(
+ [u_xy * (1. - cos_) - u_z * sin_, cos_ + u_yy * (1. - cos_), u_yz * (1. - cos_) + u_x * sin_], dim=0
+ )
+ # print(f"row_b: {row_b.size()}")
+ row_c = torch.stack(
+ [u_xz * (1. - cos_) + u_y * sin_, u_yz * (1. - cos_) - u_x * sin_, cos_ + u_zz * (1. - cos_)], dim=0
+ )
+ # print(f"row_c: {row_c.size()}")
+
+ ### rot_mtx for the rot_mtx ###
+ rot_mtx = torch.stack(
+ [row_a, row_b, row_c], dim=-1 ### rot_matrix of he matrix ##
+ )
+
+ return rot_mtx
+
+
+def calibreate_urdf_files(urdf_fn):
+ # active_robot = parse_data_from_urdf(xml_fn)
+ active_robot = parse_data_from_urdf(urdf_fn)
+ tot_joints = active_robot.tot_joints
+
+ # class Joint_urdf: #
+ # def __init__(self, name, joint_type, parent_link, child_link, origin_xyz, axis_xyz, limit: Joint_Limit) -> None:
+ # self.name = name
+ # self.type = joint_type
+ # self.parent_link = parent_link
+ # self.child_link = child_link
+ # self.origin_xyz = origin_xyz
+ # self.axis_xyz = axis_xyz
+ # self.limit = limit
+
+ with open(urdf_fn) as rf:
+ urdf_string = rf.read()
+ for cur_joint in tot_joints:
+ print(f"type: {cur_joint.type}, origin: {cur_joint.origin_xyz}")
+ cur_joint_origin = cur_joint.origin_xyz
+ scaled_joint_origin = cur_joint_origin * 3.
+ cur_joint_origin_string = cur_joint.origin_xyz_string
+ if len(cur_joint_origin_string) == 0 or torch.sum(cur_joint_origin).item() == 0.:
+ continue
+ #
+ cur_joint_origin_string_wtag = ""
+ scaled_joint_origin_string_wtag = ""
+ # scaled_joint_origin_string = f"{scaled_joint_origin[0].item()} {scaled_joint_origin[1].item()} {scaled_joint_origin[2].item()}"
+ # urdf_string = urdf_string.replace(cur_joint_origin_string, scaled_joint_origin_string)
+ urdf_string = urdf_string.replace(cur_joint_origin_string_wtag, scaled_joint_origin_string_wtag)
+ changed_urdf_fn = urdf_fn.replace(".urdf", "_scaled.urdf")
+ with open(changed_urdf_fn, "w") as wf:
+ wf.write(urdf_string)
+ print(f"changed_urdf_fn: {changed_urdf_fn}")
+ # exit(0)
+
+
+def get_GT_states_data_from_ckpt(ckpt_fn):
+ mano_nn_substeps = 1
+ num_steps = 60
+ mano_robot_actions = nn.Embedding(
+ num_embeddings=num_steps * mano_nn_substeps, embedding_dim=60,
+ )
+ torch.nn.init.zeros_(mano_robot_actions.weight)
+ # params_to_train += list(self.robot_actions.parameters())
+
+ mano_robot_delta_states = nn.Embedding(
+ num_embeddings=num_steps * mano_nn_substeps, embedding_dim=60,
+ )
+ torch.nn.init.zeros_(mano_robot_delta_states.weight)
+ # params_to_train += list(self.robot_delta_states.parameters())
+
+ mano_robot_init_states = nn.Embedding(
+ num_embeddings=1, embedding_dim=60,
+ )
+ torch.nn.init.zeros_(mano_robot_init_states.weight)
+ # params_to_train += list(self.robot_init_states.parameters())
+
+ mano_robot_glb_rotation = nn.Embedding(
+ num_embeddings=num_steps * mano_nn_substeps, embedding_dim=4
+ )
+ mano_robot_glb_rotation.weight.data[:, 0] = 1.
+ mano_robot_glb_rotation.weight.data[:, 1:] = 0.
+ # params_to_train += list(self.robot_glb_rotation.parameters())
+
+
+ mano_robot_glb_trans = nn.Embedding(
+ num_embeddings=num_steps * mano_nn_substeps, embedding_dim=3
+ )
+ torch.nn.init.zeros_(mano_robot_glb_trans.weight)
+ # params_to_train += list(self.robot_glb_trans.parameters())
+
+ mano_robot_states = nn.Embedding(
+ num_embeddings=num_steps * mano_nn_substeps, embedding_dim=60,
+ )
+ torch.nn.init.zeros_(mano_robot_states.weight)
+ mano_robot_states.weight.data[0, :] = mano_robot_init_states.weight.data[0, :].clone()
+
+
+ ''' Load optimized MANO hand actions and states '''
+ # ### laod optimized init actions #### #
+ # if 'model.load_optimized_init_actions' in self.conf and len(self.conf['model.load_optimized_init_actions']) > 0:
+ # print(f"[MANO] Loading optimized init transformations from {self.conf['model.load_optimized_init_actions']}")
+ cur_optimized_init_actions_fn = ckpt_fn
+ optimized_init_actions_ckpt = torch.load(cur_optimized_init_actions_fn, map_location='cpu', )
+
+ if 'mano_robot_states' in optimized_init_actions_ckpt:
+ mano_robot_states.load_state_dict(optimized_init_actions_ckpt['mano_robot_states'])
+
+ if 'mano_robot_init_states' in optimized_init_actions_ckpt:
+ mano_robot_init_states.load_state_dict(optimized_init_actions_ckpt['mano_robot_init_states'])
+
+ if 'mano_robot_glb_rotation' in optimized_init_actions_ckpt:
+ mano_robot_glb_rotation.load_state_dict(optimized_init_actions_ckpt['mano_robot_glb_rotation'])
+
+ if 'mano_robot_glb_trans' in optimized_init_actions_ckpt: # mano_robot_glb_trans
+ mano_robot_glb_trans.load_state_dict(optimized_init_actions_ckpt['mano_robot_glb_trans'])
+
+ mano_glb_trans_np_data = mano_robot_glb_trans.weight.data.detach().cpu().numpy()
+ mano_glb_rotation_np_data = mano_robot_glb_rotation.weight.data.detach().cpu().numpy()
+ mano_states_np_data = mano_robot_states.weight.data.detach().cpu().numpy()
+
+ if optimized_init_actions_ckpt is not None and 'object_transl' in optimized_init_actions_ckpt:
+ object_transl = optimized_init_actions_ckpt['object_transl'].detach().cpu().numpy()
+ object_global_orient = optimized_init_actions_ckpt['object_global_orient'].detach().cpu().numpy()
+
+ print(mano_robot_states.weight.data[1])
+
+ #### TODO: add an arg to control where to save the gt-reference-data ####
+ sv_gt_refereces = {
+ 'mano_glb_rot': mano_glb_rotation_np_data,
+ 'mano_glb_trans': mano_glb_trans_np_data,
+ 'mano_states': mano_states_np_data,
+ 'obj_rot': object_global_orient,
+ 'obj_trans': object_transl
+ }
+ # sv_gt_refereces_fn = "/home/xueyi/diffsim/Control-VAE/Data/ReferenceData/grab_train_split_20_cube_data.npy"
+ # sv_gt_refereces_fn = "/home/xueyi/diffsim/Control-VAE/Data/ReferenceData/grab_train_split_25_ball_data.npy"
+ # sv_gt_refereces_fn = "/home/xueyi/diffsim/Control-VAE/Data/ReferenceData/grab_train_split_54_cylinder_data.npy"
+ sv_gt_refereces_fn = "/home/xueyi/diffsim/Control-VAE/Data/ReferenceData/grab_train_split_1_dingshuji_data.npy"
+ np.save(sv_gt_refereces_fn, sv_gt_refereces)
+ print(f'gt reference data saved to {sv_gt_refereces_fn}')
+ #### TODO: add an arg to control where to save the gt-reference-data ####
+
+def scale_and_save_meshes(meshes_folder):
+ minn_robo_pts = -0.1
+ maxx_robo_pts = 0.2
+ extent_robo_pts = maxx_robo_pts - minn_robo_pts
+ mult_const_after_cent = 0.437551664260203 ## should modify
+
+ mult_const_after_cent = mult_const_after_cent / 3. * 0.9507
+
+ meshes_fn = os.listdir(meshes_folder)
+ meshes_fn = [fn for fn in meshes_fn if fn.endswith(".obj") and "scaled" not in fn]
+ for cur_fn in meshes_fn:
+ cur_mesh_name = cur_fn.split(".")[0]
+ print(f"cur_mesh_name: {cur_mesh_name}")
+ scaled_mesh_name = cur_mesh_name + "_scaled_bullet.obj"
+ full_mesh_fn = os.path.join(meshes_folder, cur_fn)
+ scaled_mesh_fn = os.path.join(meshes_folder, scaled_mesh_name)
+ try:
+ cur_mesh = trimesh.load_mesh(full_mesh_fn)
+ except:
+ continue
+ cur_mesh.vertices = cur_mesh.vertices
+
+ if 'palm' in cur_mesh_name:
+ cur_mesh.vertices = (cur_mesh.vertices - minn_robo_pts) / extent_robo_pts
+ cur_mesh.vertices = cur_mesh.vertices * 2. -1.
+ cur_mesh.vertices = cur_mesh.vertices * mult_const_after_cent # mult_const #
+ else:
+ cur_mesh.vertices = (cur_mesh.vertices) / extent_robo_pts
+ cur_mesh.vertices = cur_mesh.vertices * 2. # -1.
+ cur_mesh.vertices = cur_mesh.vertices * mult_const_after_cent # mult_const #
+
+ cur_mesh.export(scaled_mesh_fn)
+ print(f"scaled_mesh_fn: {scaled_mesh_fn}")
+ exit(0)
+
+def scale_and_save_meshes_v2(meshes_folder):
+ # /home/xueyi/diffsim/NeuS/rsc/redmax_hand/meshes/hand/body0_centered_scaled_v2.obj
+ for body_idx in range(0, 18):
+ cur_body_mesh_fn = f"body{body_idx}_centered_scaled_v2.obj"
+ cur_body_mesh_fn = os.path.join(meshes_folder, cur_body_mesh_fn)
+ cur_body_rescaled_mesh_fn = f"body{body_idx}_centered_scaled_v2_rescaled_grab.obj"
+ cur_body_rescaled_mesh_fn = os.path.join(meshes_folder, cur_body_rescaled_mesh_fn)
+ cur_mesh = trimesh.load_mesh(cur_body_mesh_fn)
+
+ cur_mesh.vertices = cur_mesh.vertices / 4.0
+ cur_mesh.export(cur_body_rescaled_mesh_fn)
+
+ # minn_robo_pts = -0.1
+ # maxx_robo_pts = 0.2
+ # extent_robo_pts = maxx_robo_pts - minn_robo_pts
+ # mult_const_after_cent = 0.437551664260203 ## should modify
+
+ # mult_const_after_cent = mult_const_after_cent / 3. * 0.9507
+
+ # meshes_fn = os.listdir(meshes_folder)
+ # meshes_fn = [fn for fn in meshes_fn if fn.endswith(".obj") and "scaled" not in fn]
+ # for cur_fn in meshes_fn:
+ # cur_mesh_name = cur_fn.split(".")[0]
+ # print(f"cur_mesh_name: {cur_mesh_name}")
+ # scaled_mesh_name = cur_mesh_name + "_scaled_bullet.obj"
+ # full_mesh_fn = os.path.join(meshes_folder, cur_fn)
+ # scaled_mesh_fn = os.path.join(meshes_folder, scaled_mesh_name)
+ # try:
+ # cur_mesh = trimesh.load_mesh(full_mesh_fn)
+ # except:
+ # continue
+ # cur_mesh.vertices = cur_mesh.vertices
+
+ # if 'palm' in cur_mesh_name:
+ # cur_mesh.vertices = (cur_mesh.vertices - minn_robo_pts) / extent_robo_pts
+ # cur_mesh.vertices = cur_mesh.vertices * 2. -1.
+ # cur_mesh.vertices = cur_mesh.vertices * mult_const_after_cent # mult_const #
+ # else:
+ # cur_mesh.vertices = (cur_mesh.vertices) / extent_robo_pts
+ # cur_mesh.vertices = cur_mesh.vertices * 2. # -1.
+ # cur_mesh.vertices = cur_mesh.vertices * mult_const_after_cent # mult_const #
+
+ # cur_mesh.export(scaled_mesh_fn)
+ # print(f"scaled_mesh_fn: {scaled_mesh_fn}")
+ exit(0)
+
+
+def calibreate_urdf_files_v2(urdf_fn):
+ # active_robot = parse_data_from_urdf(xml_fn)
+ active_robot = parse_data_from_urdf(urdf_fn)
+ tot_joints = active_robot.tot_joints
+ tot_links = active_robot.link_name_to_link_struct
+
+ minn_robo_pts = -0.1
+ maxx_robo_pts = 0.2
+ extent_robo_pts = maxx_robo_pts - minn_robo_pts
+ mult_const_after_cent = 0.437551664260203 ## should modify
+
+ mult_const_after_cent = mult_const_after_cent / 3. * 0.9507
+
+ with open(urdf_fn) as rf:
+ urdf_string = rf.read()
+ for cur_joint in tot_joints:
+ # print(f"type: {cur_joint.type}, origin: {cur_joint.origin_xyz}")
+ cur_joint_origin = cur_joint.origin_xyz
+
+ # cur_joint_origin = (cur_joint_origin / extent_robo_pts) * 2.0 * mult_const_after_cent
+
+ # cur_joint_origin = (cur_joint_origin / extent_robo_pts) * 2.0 * mult_const_after_cent
+
+ if cur_joint.name in ['FFJ4' , 'MFJ4' ,'RFJ4' ,'LFJ5' ,'THJ5']:
+ cur_joint_origin = (cur_joint_origin - minn_robo_pts) / extent_robo_pts
+ cur_joint_origin = cur_joint_origin * 2.0 - 1.0
+ cur_joint_origin = cur_joint_origin * mult_const_after_cent
+ else:
+ cur_joint_origin = (cur_joint_origin) / extent_robo_pts
+ cur_joint_origin = cur_joint_origin * 2.0 # - 1.0
+ cur_joint_origin = cur_joint_origin * mult_const_after_cent
+
+
+ origin_list = cur_joint_origin.detach().cpu().tolist()
+ origin_list = [str(cur_val) for cur_val in origin_list]
+ origin_str = " ".join(origin_list)
+ print(f"name: {cur_joint.name}, cur_joint_origin: {origin_str}")
+
+ # scaled_joint_origin = cur_joint_origin * 3.
+ # cur_joint_origin_string = cur_joint.origin_xyz_string
+ # if len(cur_joint_origin_string) == 0 or torch.sum(cur_joint_origin).item() == 0.:
+ # continue
+ # #
+ # cur_joint_origin_string_wtag = ""
+ # scaled_joint_origin_string_wtag = ""
+ # # scaled_joint_origin_string = f"{scaled_joint_origin[0].item()} {scaled_joint_origin[1].item()} {scaled_joint_origin[2].item()}"
+ # # urdf_string = urdf_string.replace(cur_joint_origin_string, scaled_joint_origin_string)
+ # urdf_string = urdf_string.replace(cur_joint_origin_string_wtag, scaled_joint_origin_string_wtag)
+ # changed_urdf_fn = urdf_fn.replace(".urdf", "_scaled.urdf")
+ # with open(changed_urdf_fn, "w") as wf:
+ # wf.write(urdf_string)
+ # print(f"changed_urdf_fn: {changed_urdf_fn}")
+ # # exit(0)
+
+ # for cur_link_nm in tot_links:
+ # cur_link = tot_links[cur_link_nm]
+ # if cur_link.visual is None:
+ # continue
+ # xyz_visual = cur_link.visual.visual_xyz
+ # xyz_visual = (xyz_visual / extent_robo_pts) * 2.0 * mult_const_after_cent
+ # xyz_visual_list = xyz_visual.detach().cpu().tolist()
+ # xyz_visual_list = [str(cur_val) for cur_val in xyz_visual_list]
+ # xyz_visual_str = " ".join(xyz_visual_list)
+ # print(f"name: {cur_link.name}, xyz_visual: {xyz_visual_str}")
+
+def get_shadow_GT_states_data_from_ckpt(ckpt_fn):
+ mano_nn_substeps = 1
+ num_steps = 60
+
+
+ # robot actions # #
+ # robot_actions = nn.Embedding(
+ # num_embeddings=num_steps, embedding_dim=22,
+ # ).cuda()
+ # torch.nn.init.zeros_(robot_actions.weight)
+ # # params_to_train += list(robot_actions.parameters())
+
+ # robot_delta_states = nn.Embedding(
+ # num_embeddings=num_steps, embedding_dim=60,
+ # ).cuda()
+ # torch.nn.init.zeros_(robot_delta_states.weight)
+ # # params_to_train += list(robot_delta_states.parameters())
+
+
+ robot_states = nn.Embedding(
+ num_embeddings=num_steps, embedding_dim=60,
+ ).cuda()
+ torch.nn.init.zeros_(robot_states.weight)
+ # params_to_train += list(robot_states.parameters())
+
+ # robot_init_states = nn.Embedding(
+ # num_embeddings=1, embedding_dim=22,
+ # ).cuda()
+ # torch.nn.init.zeros_(robot_init_states.weight)
+ # # params_to_train += list(robot_init_states.parameters())
+
+ robot_glb_rotation = nn.Embedding(
+ num_embeddings=num_steps, embedding_dim=4
+ ).cuda()
+ robot_glb_rotation.weight.data[:, 0] = 1.
+ robot_glb_rotation.weight.data[:, 1:] = 0.
+
+
+ robot_glb_trans = nn.Embedding(
+ num_embeddings=num_steps, embedding_dim=3
+ ).cuda()
+ torch.nn.init.zeros_(robot_glb_trans.weight)
+
+ ''' Load optimized MANO hand actions and states '''
+ cur_optimized_init_actions_fn = ckpt_fn
+ optimized_init_actions_ckpt = torch.load(cur_optimized_init_actions_fn, map_location='cpu', )
+
+ print(f"optimized_init_actions_ckpt: {optimized_init_actions_ckpt.keys()}")
+
+ if 'robot_glb_rotation' in optimized_init_actions_ckpt:
+ robot_glb_rotation.load_state_dict(optimized_init_actions_ckpt['robot_glb_rotation'])
+
+ if 'robot_states' in optimized_init_actions_ckpt:
+ robot_states.load_state_dict(optimized_init_actions_ckpt['robot_states'])
+
+ if 'robot_glb_trans' in optimized_init_actions_ckpt:
+ robot_glb_trans.load_state_dict(optimized_init_actions_ckpt['robot_glb_trans'])
+
+ # if 'mano_robot_glb_trans' in optimized_init_actions_ckpt: # mano_robot_glb_trans
+ # mano_robot_glb_trans.load_state_dict(optimized_init_actions_ckpt['mano_robot_glb_trans'])
+
+ robot_glb_trans_np_data = robot_glb_trans.weight.data.detach().cpu().numpy()
+ robot_glb_rotation_np_data = robot_glb_rotation.weight.data.detach().cpu().numpy()
+ robot_states_np_data = robot_states.weight.data.detach().cpu().numpy()
+
+ if optimized_init_actions_ckpt is not None and 'object_transl' in optimized_init_actions_ckpt:
+ object_transl = optimized_init_actions_ckpt['object_transl'].detach().cpu().numpy()
+ object_global_orient = optimized_init_actions_ckpt['object_global_orient'].detach().cpu().numpy()
+
+ # print(mano_robot_states.weight.data[1])
+
+ #### TODO: add an arg to control where to save the gt-reference-data ####
+ sv_gt_refereces = {
+ 'mano_glb_rot': robot_glb_rotation_np_data,
+ 'mano_glb_trans': robot_glb_trans_np_data,
+ 'mano_states': robot_states_np_data,
+ 'obj_rot': object_global_orient,
+ 'obj_trans': object_transl
+ }
+ # sv_gt_refereces_fn = "/home/xueyi/diffsim/Control-VAE/Data/ReferenceData/grab_train_split_20_cube_data.npy"
+ # sv_gt_refereces_fn = "/home/xueyi/diffsim/Control-VAE/Data/ReferenceData/grab_train_split_25_ball_data.npy"
+ # sv_gt_refereces_fn = "/home/xueyi/diffsim/Control-VAE/Data/ReferenceData/grab_train_split_54_cylinder_data.npy"
+ # sv_gt_refereces_fn = "/home/xueyi/diffsim/Control-VAE/Data/ReferenceData/shadow_grab_train_split_224_tiantianquan_data.npy"
+ sv_gt_refereces_fn = "/home/xueyi/diffsim/Control-VAE/Data/ReferenceData/shadow_grab_train_split_54_cylinder_data.npy"
+ np.save(sv_gt_refereces_fn, sv_gt_refereces)
+ print(f'gt reference data saved to {sv_gt_refereces_fn}')
+ #### TODO: add an arg to control where to save the gt-reference-data ####
+
+## saved the robot file ##
+
+
+def calibrate_left_shadow_hand():
+ rgt_shadow_hand_des_folder = "/home/xueyi/diffsim/NeuS/rsc/shadow_hand_description"
+ lft_shadow_hand_des_folder = "/home/xueyi/diffsim/NeuS/rsc/shadow_hand_description_left"
+ os.makedirs(lft_shadow_hand_des_folder, exist_ok=True)
+ lft_shadow_hand_mesh_folder = os.path.join(lft_shadow_hand_des_folder, "meshes")
+ os.makedirs(lft_shadow_hand_mesh_folder, exist_ok=True)
+ rgt_shadow_hand_mesh_folder = os.path.join(rgt_shadow_hand_des_folder, "meshes")
+ tot_rgt_hand_meshes = os.listdir(rgt_shadow_hand_mesh_folder)
+ tot_rgt_hand_meshes = [fn for fn in tot_rgt_hand_meshes if fn.endswith(".obj")]
+ for cur_hand_mesh_fn in tot_rgt_hand_meshes:
+ full_rgt_mesh_fn = os.path.join(rgt_shadow_hand_mesh_folder, cur_hand_mesh_fn)
+ try:
+ full_rgt_mesh = trimesh.load(full_rgt_mesh_fn, force='mesh')
+ except:
+ continue
+ full_rgt_mesh_verts = full_rgt_mesh.vertices
+ full_rgt_mesh_faces = full_rgt_mesh.faces
+ full_rgt_mesh_verts[:, 1] = -1. * full_rgt_mesh_verts[:, 1] ## flip the y-axis
+ lft_mesh = trimesh.Trimesh(vertices=full_rgt_mesh_verts, faces=full_rgt_mesh_faces)
+ lft_mesh_fn = os.path.join(lft_shadow_hand_mesh_folder, cur_hand_mesh_fn)
+ lft_mesh.export(lft_mesh_fn)
+ print(f"lft_mesh_fn: {lft_mesh_fn}")
+ exit(0)
+
+
+## urd for the left hand
+def calibreate_urdf_files_left_hand(urdf_fn):
+ # active_robot = parse_data_from_urdf(xml_fn)
+ active_robot = parse_data_from_urdf(urdf_fn)
+ tot_joints = active_robot.tot_joints
+ tot_links = active_robot.link_name_to_link_struct
+
+ minn_robo_pts = -0.1
+ maxx_robo_pts = 0.2
+ extent_robo_pts = maxx_robo_pts - minn_robo_pts
+ mult_const_after_cent = 0.437551664260203 ## should modify
+
+ mult_const_after_cent = mult_const_after_cent / 3. * 0.9507
+
+ with open(urdf_fn) as rf:
+ urdf_string = rf.read()
+ for cur_joint in tot_joints:
+ # print(f"type: {cur_joint.type}, origin: {cur_joint.origin_xyz}")
+ cur_joint_origin = cur_joint.origin_xyz
+
+ cur_joint_axis = cur_joint.axis_xyz
+
+ cur_joint_origin = cur_joint_origin.detach()
+ cur_joint_axis = cur_joint_axis.detach()
+
+ cur_joint_origin[1] = -1.0 * cur_joint_origin[1]
+ cur_joint_axis[1] = -1.0 * cur_joint_axis[1]
+
+
+ origin_list = cur_joint_origin.detach().cpu().tolist()
+ origin_list = [str(cur_val) for cur_val in origin_list]
+ origin_str = " ".join(origin_list)
+
+ axis_list = cur_joint_axis.detach().cpu().tolist()
+ axis_list = [str(cur_val) for cur_val in axis_list]
+ axis_str = " ".join(axis_list)
+ print(f"name: {cur_joint.name}, cur_joint_origin: {origin_str}, axis_str: {axis_str}")
+
+ # cur_joint_origin = (cur_joint_origin / extent_robo_pts) * 2.0 * mult_const_after_cent
+
+ # cur_joint_origin = (cur_joint_origin / extent_robo_pts) * 2.0 * mult_const_after_cent
+
+ # if cur_joint.name in ['FFJ4' , 'MFJ4' ,'RFJ4' ,'LFJ5' ,'THJ5']:
+ # cur_joint_origin = (cur_joint_origin - minn_robo_pts) / extent_robo_pts
+ # cur_joint_origin = cur_joint_origin * 2.0 - 1.0
+ # cur_joint_origin = cur_joint_origin * mult_const_after_cent
+ # else:
+ # cur_joint_origin = (cur_joint_origin) / extent_robo_pts
+ # cur_joint_origin = cur_joint_origin * 2.0 # - 1.0
+ # cur_joint_origin = cur_joint_origin * mult_const_after_cent
+
+
+ # origin_list = cur_joint_origin.detach().cpu().tolist()
+ # origin_list = [str(cur_val) for cur_val in origin_list]
+ # origin_str = " ".join(origin_list)
+ # print(f"name: {cur_joint.name}, cur_joint_origin: {origin_str}")
+
+def calibreate_urdf_files_v4(urdf_fn, dst_urdf_fn):
+ # active_robot = parse_data_from_urdf(xml_fn)
+ active_robot = parse_data_from_urdf(urdf_fn)
+ tot_joints = active_robot.tot_joints
+ tot_links = active_robot.link_name_to_link_struct
+
+ # minn_robo_pts = -0.1
+ # maxx_robo_pts = 0.2
+ # extent_robo_pts = maxx_robo_pts - minn_robo_pts
+ # mult_const_after_cent = 0.437551664260203 ## should modify
+
+ # mult_const_after_cent = mult_const_after_cent / 3. * 0.9507
+
+ with open(urdf_fn) as rf:
+ urdf_string = rf.read()
+ for cur_joint in tot_joints:
+ # print(f"type: {cur_joint.type}, origin: {cur_joint.origin_xyz}")
+ cur_joint_origin = cur_joint.origin_xyz
+ modified_joint_origin = cur_joint_origin / 4.
+
+ origin_list = cur_joint_origin.detach().cpu().tolist()
+ origin_list = [str(cur_val) for cur_val in origin_list]
+ origin_str = " ".join(origin_list)
+
+ dst_list = modified_joint_origin.detach().cpu().tolist()
+ dst_list = [str(cur_val) for cur_val in dst_list]
+ dst_str = " ".join(dst_list)
+
+ urdf_string = urdf_string.replace(origin_str, dst_str)
+
+ with open(dst_urdf_fn, "w") as wf:
+ wf.write(urdf_string)
+ wf.close()
+ # # cur_joint_origin = (cur_joint_origin / extent_robo_pts) * 2.0 * mult_const_after_cent
+
+ # # cur_joint_origin = (cur_joint_origin / extent_robo_pts) * 2.0 * mult_const_after_cent
+
+ # if cur_joint.name in ['FFJ4' , 'MFJ4' ,'RFJ4' ,'LFJ5' ,'THJ5']:
+ # cur_joint_origin = (cur_joint_origin - minn_robo_pts) / extent_robo_pts
+ # cur_joint_origin = cur_joint_origin * 2.0 - 1.0
+ # cur_joint_origin = cur_joint_origin * mult_const_after_cent
+ # else:
+ # cur_joint_origin = (cur_joint_origin) / extent_robo_pts
+ # cur_joint_origin = cur_joint_origin * 2.0 # - 1.0
+ # cur_joint_origin = cur_joint_origin * mult_const_after_cent
+
+
+ # origin_list = cur_joint_origin.detach().cpu().tolist()
+ # origin_list = [str(cur_val) for cur_val in origin_list]
+ # origin_str = " ".join(origin_list)
+ # print(f"name: {cur_joint.name}, cur_joint_origin: {origin_str}")
+
+
+def test_gt_ref_data(gt_ref_data_fn):
+ cur_gt_ref_data = np.load(gt_ref_data_fn, allow_pickle=True).item()
+ print(cur_gt_ref_data.keys())
+
+ mano_glb_rot, glb_trans, states = cur_gt_ref_data['mano_glb_rot'], cur_gt_ref_data['mano_glb_trans'], cur_gt_ref_data['mano_states']
+ return mano_glb_rot, glb_trans, states
+
+
+def get_states(gt_ref_data_fn):
+ states = np.load(gt_ref_data_fn, allow_pickle=True).item()
+ return states['target']
+
+#### Big TODO: the external contact forces from the manipulated object to the robot ####
+if __name__=='__main__': # # #
+
+ gt_ref_data_fn = "/home/xueyi/diffsim/Control-VAE/Data/ReferenceData/shadow_grab_train_split_85_bunny_wact_data.npy"
+ # mano_glb_rot, glb_trans, states = test_gt_ref_data(gt_ref_data_fn)
+ # eixt(0)
+ mano_states_fn = '/home/xueyi/diffsim/NeuS/raw_data/evalulated_traj_sm_l512_wana_v3_subiters1024_optim_params_shadow_85_bunny_std0d01_netv1_mass10000_new_dp1d0_wtable_gn9d8__step_2.npy'
+ mano_states_fn = '/home/xueyi/diffsim/NeuS/raw_data/evalulated_traj_sm_l512_wana_v3_subiters1024_optim_params_shadow_102_mouse_wact_std0d01_netv1_mass10000_new_dp1d0_dtv2tsv2ctlv2_netv3optt_lstd_langdamp_wcmase_ni4_wtable_gn_adp1d0_trwmonly_cs0d6_predtarmano_wambient__step_9.npy'
+ mano_states = get_states(mano_states_fn)
+
+ blended_ratio = 0.5
+
+ blended_states = []
+
+ tot_rot_mtxes = []
+ tot_trans = []
+ for i_state in range(len(mano_states)):
+ cur_trans = mano_states[i_state][:3]
+ cur_rot = mano_states[i_state][3:6]
+ cur_states = mano_states[i_state][6:]
+
+ cur_rot_struct = R.from_euler('zyx', cur_rot[[2, 1, 0]], degrees=False)
+ cur_rot_mtx = cur_rot_struct.as_matrix()
+
+ tot_rot_mtxes.append(cur_rot_mtx)
+ tot_trans.append(cur_trans)
+
+
+ cur_state = cur_states # states[i_state]
+ cur_modified_state = mano_states[0][6:] + (cur_state - mano_states[0][6:] ) * blended_ratio
+
+ cur_modified_state = np.concatenate([np.zeros((2,), dtype=np.float32), cur_modified_state], axis=-1)
+ blended_states.append(cur_modified_state)
+ # return blended_states
+
+ tot_rot_mtxes = np.stack(tot_rot_mtxes, axis=0)
+ tot_trans = np.stack(tot_trans, axis=0)
+ blended_states = np.stack(blended_states, axis=0)
+
+ # urdf_fn = "/home/xueyi/diffsim/NeuS/rsc/redmax_hand/redmax_hand_test_3_wcollision.urdf"
+ # dst_urdf_fn = "/home/xueyi/diffsim/NeuS/rsc/redmax_hand/redmax_hand_test_3_wcollision_rescaled_grab.urdf"
+ # calibreate_urdf_files_v4(urdf_fn, dst_urdf_fn)
+ # exit(0)
+
+ # meshes_folder = "/home/xueyi/diffsim/NeuS/rsc/redmax_hand/meshes/hand"
+ # scale_and_save_meshes_v2(meshes_folder)
+ # exit(0)
+
+ urdf_fn = "/home/xueyi/diffsim/NeuS/rsc/shadow_hand_description/shadowhand_new_scaled.urdf"
+ urdf_fn = "/home/xueyi/diffsim/NeuS/rsc/shadow_hand_description/shadowhand_new_scaled_nroot_new.urdf"
+ # urdf_fn = "/home/xueyi/diffsim/NeuS/rsc/shadow_hand_description/shadowhand_new_scaled_nroot.urdf"
+ # urdf_fn = "/home/xueyi/diffsim/NeuS/rsc/redmax_hand/redmax_hand_test_3_wcollision_rescaled_grab.urdf"
+ robot_agent = RobotAgent(urdf_fn)
+
+ init_vertices, init_faces = robot_agent.active_robot.init_vertices, robot_agent.active_robot.init_faces
+ init_vertices = init_vertices.detach().cpu().numpy()
+ init_faces = init_faces.detach().cpu().numpy()
+
+ tot_transformed_pts = []
+ for i_ts in range(len(blended_states)):
+ cur_blended_states = blended_states[i_ts]
+ cur_blended_states = torch.from_numpy(cur_blended_states).float().cuda()
+ robot_agent.active_robot.set_delta_state_and_update_v2(cur_blended_states, 0)
+ cur_pts = robot_agent.get_init_state_visual_pts().detach().cpu().numpy()
+
+ cur_pts_transformed = np.matmul(
+ tot_rot_mtxes[i_ts], cur_pts.T
+ ).T + tot_trans[i_ts][None]
+ tot_transformed_pts.append(cur_pts_transformed)
+ tot_transformed_pts = np.stack(tot_transformed_pts, axis=0)
+ np.save("/home/xueyi/diffsim/NeuS/raw_data/transformed_pts.npy", {'tot_transformed_pts': tot_transformed_pts, 'init_faces': init_faces})
+ exit(0)
+
+
+ robot_agent.active_robot.set_delta_state_and_update_v2()
+
+ init_vertices, init_faces = robot_agent.active_robot.init_vertices, robot_agent.active_robot.init_faces
+ init_vertices = init_vertices.detach().cpu().numpy()
+ init_faces = init_faces.detach().cpu().numpy()
+ print(f"init_vertices: {init_vertices.shape}, init_faces: {init_faces.shape}")
+ shadow_hand_mesh = trimesh.Trimesh(vertices=init_vertices, faces=init_faces)
+ # shadow_hand_sv_fn = "/home/xueyi/diffsim/NeuS/raw_data/shadow_hand_lft.obj"
+ shadow_hand_sv_fn = "/home/xueyi/diffsim/NeuS/raw_data/shadow_hand_new.ply"
+ shadow_hand_mesh.export(shadow_hand_sv_fn)
+ np.save("/home/xueyi/diffsim/NeuS/raw_data/faces.npy", init_faces)
+
+ exit(0)
+
+ init_vertices, init_faces = robot_agent.active_robot.init_vertices, robot_agent.active_robot.init_faces
+ init_vertices = init_vertices.detach().cpu().numpy()
+ init_faces = init_faces.detach().cpu().numpy()
+ print(f"init_vertices: {init_vertices.shape}, init_faces: {init_faces.shape}")
+ shadow_hand_mesh = trimesh.Trimesh(vertices=init_vertices, faces=init_faces)
+ # shadow_hand_sv_fn = "/home/xueyi/diffsim/NeuS/raw_data/shadow_hand_lft.obj"
+ shadow_hand_sv_fn = "/home/xueyi/diffsim/NeuS/raw_data/scaled_shadow_hand.obj"
+ shadow_hand_sv_fn = "/home/xueyi/diffsim/NeuS/raw_data/scaled_redmax_hand_rescaled_grab.obj"
+ shadow_hand_mesh.export(shadow_hand_sv_fn)
+
+ init_joint_states = torch.randn((60, ), dtype=torch.float32).cuda()
+ robot_agent.set_initial_state(init_joint_states)
+
+
+ cur_verts, cur_faces = robot_agent.get_init_visual_pts()
+ cur_mesh = trimesh.Trimesh(vertices=cur_verts.detach().cpu().numpy(), faces=cur_faces.detach().cpu().numpy())
+ shadow_hand_sv_fn = "/home/xueyi/diffsim/NeuS/raw_data/scaled_redmax_hand_rescaled_grab_wstate.obj"
+ cur_mesh.export(shadow_hand_sv_fn)
+ exit(0)
+
+
+
+ urdf_fn = "/home/xueyi/diffsim/NeuS/rsc/shadow_hand_description/shadowhand_new_scaled.urdf"
+
+ ##
+ lft_urdf_fn = "/home/xueyi/diffsim/NeuS/rsc/shadow_hand_description_left/shadowhand_left_new_scaled.urdf"
+
+
+ urdf_fn = "/home/xueyi/diffsim/NeuS/rsc/redmax_hand/redmax_hand_test_3_wcollision.urdf"
+
+ ##
+ lft_urdf_fn = "/home/xueyi/diffsim/NeuS/rsc/redmax_hand/redmax_hand_test_3_wcollision.urdf"
+
+ robot_agent = RobotAgent(lft_urdf_fn)
+ init_vertices, init_faces = robot_agent.active_robot.init_vertices, robot_agent.active_robot.init_faces
+ init_vertices = init_vertices.detach().cpu().numpy()
+ init_faces = init_faces.detach().cpu().numpy()
+ print(f"init_vertices: {init_vertices.shape}, init_faces: {init_faces.shape}")
+ shadow_hand_mesh = trimesh.Trimesh(vertices=init_vertices, faces=init_faces)
+ shadow_hand_sv_fn = "/home/xueyi/diffsim/NeuS/raw_data/shadow_hand_lft.obj"
+ shadow_hand_sv_fn = "/home/xueyi/diffsim/NeuS/raw_data/redmax_hand.obj"
+ shadow_hand_mesh.export(shadow_hand_sv_fn)
+ exit(0)
+
+
+ rgt_urdf_fn = "/home/xueyi/diffsim/NeuS/rsc/shadow_hand_description/shadowhand_new_scaled.urdf"
+ # rgt_urdf_fn
+ calibreate_urdf_files_left_hand(rgt_urdf_fn)
+ exit(0)
+
+ calibrate_left_shadow_hand()
+ exit(0)
+
+ # ckpt_fn = "/data3/datasets/xueyi/neus/exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_retargeted_shadow_hand_states_/checkpoints/ckpt_320000.pth"
+ # ckpt_fn = "/data3/datasets/xueyi/neus/exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_retargeted_shadow_hand_states_optrobot__seq_54_optrules_/checkpoints/ckpt_030000.pth"
+ # get_shadow_GT_states_data_from_ckpt(ckpt_fn)
+ # exit(0)
+
+ # urdf_fn = "/home/xueyi/diffsim/NeuS/rsc/shadow_hand_description/shadowhand_new_scaled.urdf"
+ # calibreate_urdf_files_v2(urdf_fn)
+ # exit(0)
+
+ meshes_folder = "/home/xueyi/diffsim/NeuS/rsc/shadow_hand_description/meshes"
+ # scale_and_save_meshes(meshes_folder)
+ # exit(0)
+
+ # sv_ckpt_fn = "/data3/datasets/xueyi/neus/exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_mano_states_grab_train_54_cylinder_tst_/checkpoints/ckpt_070000.pth"
+ # sv_ckpt_fn = "/data3/datasets/xueyi/neus/exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_mano_states_grab_train_1_dingshuji_tst_/checkpoints/ckpt_070000.pth"
+ # get_GT_states_data_from_ckpt(sv_ckpt_fn)
+ # exit(0)
+
+
+ # urdf_fn = "/home/xueyi/diffsim/NeuS/rsc/mano/mano_mean_wcollision_scaled_scaled_0_9507_nroot.urdf"
+ # robot_agent = RobotAgent(urdf_fn)
+ # exit(0)
+
+ # urdf_fn = "/home/xueyi/diffsim/NeuS/rsc/mano/mano_mean_nocoll_simplified.urdf"
+ # urdf_fn = "/home/xueyi/diffsim/NeuS/rsc/mano/mano_mean_wcollision_scaled.urdf"
+ # calibreate_urdf_files(urdf_fn)
+ # exit(0)
+
+ # urdf_fn = "/home/xueyi/diffsim/NeuS/rsc/mano/mano_mean_nocoll_simplified.urdf"
+ urdf_fn = "/home/xueyi/diffsim/NeuS/rsc/shadow_hand_description/shadowhand_new.urdf"
+ robot_agent = RobotAgent(urdf_fn)
+
+ init_vertices, init_faces = robot_agent.active_robot.init_vertices, robot_agent.active_robot.init_faces
+ init_vertices = init_vertices.detach().cpu().numpy()
+ init_faces = init_faces.detach().cpu().numpy()
+
+ shadow_hand_mesh = trimesh.Trimesh(vertices=init_vertices, faces=init_faces)
+ shadow_hand_sv_fn = "/home/xueyi/diffsim/NeuS/raw_data/shadow_hand.obj"
+ shadow_hand_mesh.export(shadow_hand_sv_fn)
+ exit(0)
+
+
+ ref_dict_npy = "reference_verts.npy"
+ robot_agent.initialize_optimization(ref_dict_npy)
+ ts_to_robot_points, ts_to_ref_points = robot_agent.forward_stepping_optimization()
+ np.save(f"ts_to_robot_points.npy", ts_to_robot_points)
+ np.save(f"ts_to_ref_points.npy", ts_to_ref_points)
+ exit(0)
+
+ urdf_fn = "/home/xueyi/diffsim/NeuS/rsc/mano/mano_mean_nocoll_simplified.urdf"
+ cur_robot = parse_data_from_urdf(urdf_fn)
+ # self.init_vertices, self.init_faces
+ init_vertices, init_faces = cur_robot.init_vertices, cur_robot.init_faces
+
+
+
+ init_vertices = init_vertices.detach().cpu().numpy()
+ init_faces = init_faces.detach().cpu().numpy()
+
+
+ ## initial states ehre ##3
+ # mesh_obj = trimesh.Trimesh(vertices=init_vertices, faces=init_faces)
+ # mesh_obj.export(f"hand_urdf.ply")
+
+ ##### Test the set initial state function #####
+ init_joint_states = torch.zeros((60, ), dtype=torch.float32).cuda()
+ cur_robot.set_initial_state(init_joint_states)
+ ##### Test the set initial state function #####
+
+
+
+
+ cur_zeros_actions = torch.zeros((60, ), dtype=torch.float32).cuda()
+ cur_ones_actions = torch.ones((60, ), dtype=torch.float32).cuda() # * 100
+
+ ts_to_mesh_verts = {}
+ for i_ts in range(50):
+ cur_robot.calculate_inertia()
+
+ cur_robot.set_actions_and_update_states(cur_ones_actions, i_ts, 0.2) ###
+
+
+ cur_verts, cur_faces = cur_robot.get_init_visual_pts()
+ cur_mesh = trimesh.Trimesh(vertices=cur_verts.detach().cpu().numpy(), faces=cur_faces.detach().cpu().numpy())
+
+ ts_to_mesh_verts[i_ts + i_ts] = cur_verts.detach().cpu().numpy()
+ # cur_mesh.export(f"stated_mano_mesh.ply")
+ # cur_mesh.export(f"zero_actioned_mano_mesh.ply")
+ cur_mesh.export(f"ones_actioned_mano_mesh_ts_{i_ts}.ply")
+
+ np.save(f"reference_verts.npy", ts_to_mesh_verts)
+
+ exit(0)
+
+ xml_fn = "/home/xueyi/diffsim/DiffHand/assets/hand_sphere.xml"
+ robot_agent = RobotAgent(xml_fn=xml_fn, args=None)
+ init_visual_pts = robot_agent.init_visual_pts.detach().cpu().numpy()
+ exit(0)
+
\ No newline at end of file