meow committed
Commit • 6f7cc86
Parent(s): 710e818
- app.py +1 -1
- exp_runner_stage_1.py +0 -0
- models/dyn_model_act_v2.py +80 -80
- models/fields.py +0 -0
- pre-requirements.txt +1 -1
- requirements.txt +2 -1
- scripts_demo/train_grab_pointset_points_dyn_s1.sh +7 -3
app.py
CHANGED
@@ -149,7 +149,7 @@ def create_demo():
 inputs = input_file
 outputs = output_file
 gr.Examples(
-examples=[os.path.join(os.path.dirname(__file__), "./
+examples=[os.path.join(os.path.dirname(__file__), "./data/102_grab_all_data.npy")],
 inputs=inputs,
 fn=predict,
 outputs=outputs,
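For context, a minimal sketch of how this gr.Examples call sits in a self-contained Gradio app; the predict body and the File components are assumptions inferred from the names visible in the hunk, not the app's actual code:

import os
import gradio as gr

def predict(input_path):
    # Hypothetical stand-in for the app's predict function.
    return input_path

with gr.Blocks() as demo:
    input_file = gr.File(label="input")
    output_file = gr.File(label="output")
    # gr.Examples pre-fills the input component with the bundled .npy file
    # that this commit adds under ./data/.
    gr.Examples(
        examples=[os.path.join(os.path.dirname(__file__), "./data/102_grab_all_data.npy")],
        inputs=input_file,
        fn=predict,
        outputs=output_file,
    )

demo.launch()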
exp_runner_stage_1.py
CHANGED
The diff for this file is too large to render. See raw diff.
models/dyn_model_act_v2.py
CHANGED
@@ -147,7 +147,7 @@ class Inertial:
 self.mass = mass
 self.inertia = inertia
 if torch.sum(self.inertia).item() < 1e-4:
-self.inertia = self.inertia + torch.eye(3, dtype=torch.float32)
+self.inertia = self.inertia + torch.eye(3, dtype=torch.float32)
 pass
 
 class Visual:
@@ -189,8 +189,8 @@ class Visual:
 vertices = mesh.vertices
 faces = mesh.faces
 
-vertices = torch.from_numpy(vertices).float()
-faces =torch.from_numpy(faces).long()
+vertices = torch.from_numpy(vertices).float()
+faces =torch.from_numpy(faces).long()
 
 vertices = vertices * self.geometry_mesh_scale.unsqueeze(0) + self.visual_xyz.unsqueeze(0)
 
@@ -244,7 +244,7 @@ class Visual:
 # sample from the circile with cur_pts as thejcenter and the radius as expand_r
 # (-r, r) # sample the offset vector in the size of (nn_expand_pts, 3)
 offset_dist = Uniform(-1. * expand_r, expand_r)
-offset_vec = offset_dist.sample((nn_expand_pts, 3))
+offset_vec = offset_dist.sample((nn_expand_pts, 3))
 cur_expanded_pts = cur_pts + offset_vec
 cur_expanded_visual_pts.append(cur_expanded_pts)
 cur_expanded_visual_pts = torch.cat(cur_expanded_visual_pts, dim=0)
@@ -252,7 +252,7 @@ class Visual:
 else:
 print(f"Loading visual pts from {expand_save_fn}") # load from the fn #
 cur_expanded_visual_pts = np.load(expand_save_fn, allow_pickle=True)
-cur_expanded_visual_pts = torch.from_numpy(cur_expanded_visual_pts).float()
+cur_expanded_visual_pts = torch.from_numpy(cur_expanded_visual_pts).float()
 self.cur_expanded_visual_pts = cur_expanded_visual_pts # expanded visual pts #
 return self.cur_expanded_visual_pts
 
@@ -343,7 +343,7 @@ class Link_urdf:
 try:
 cur_child_inertia = cur_child_struct.cur_inertia
 except:
-cur_child_inertia = torch.eye(3, dtype=torch.float32)
+cur_child_inertia = torch.eye(3, dtype=torch.float32)
 
 
 if cur_joint.type in ['revolute'] and (cur_joint_name not in ['WRJ2', 'WRJ1']):
@@ -428,10 +428,10 @@ class Link_urdf:
 joint_link_idxes = torch.cat(joint_link_idxes, dim=-1) ### joint_link idxes ###
 cur_joint_idx = cur_child.link_idx
 joint_link_idxes = torch.cat(
-[joint_link_idxes, torch.tensor([cur_joint_idx], dtype=torch.long)
+[joint_link_idxes, torch.tensor([cur_joint_idx], dtype=torch.long)], dim=-1
 )
 else:
-joint_link_idxes = torch.tensor([cur_child.link_idx], dtype=torch.long).
+joint_link_idxes = torch.tensor([cur_child.link_idx], dtype=torch.long).view(1,)
 
 # joint link idxes #
 
@@ -446,7 +446,7 @@ class Link_urdf:
 
 # joint_origin_xyz = self.joint.origin_xyz # c ## get forces from the expanded point set ##
 else:
-joint_origin_xyz = torch.tensor([0., 0., 0.], dtype=torch.float32)
+joint_origin_xyz = torch.tensor([0., 0., 0.], dtype=torch.float32)
 # self.parent_rot_mtx = parent_rot
 # self.parent_trans_vec = parent_trans + joint_origin_xyz
 
@@ -455,13 +455,13 @@ class Link_urdf:
 # ## get init visual meshes ## ## --
 init_visual_meshes = self.visual.get_init_visual_meshes(parent_rot, parent_trans, init_visual_meshes, expanded_pts=expanded_pts)
 cur_visual_mesh_pts_nn = self.visual.vertices.size(0)
-cur_link_idxes = torch.zeros((cur_visual_mesh_pts_nn, ), dtype=torch.long)
+cur_link_idxes = torch.zeros((cur_visual_mesh_pts_nn, ), dtype=torch.long)+ self.link_idx
 init_visual_meshes['link_idxes'].append(cur_link_idxes)
 
 # self.link_idx #
 if joint_idxes is not None:
 cur_idxes = [self.link_idx for _ in range(cur_visual_mesh_pts_nn)]
-cur_idxes = torch.tensor(cur_idxes, dtype=torch.long)
+cur_idxes = torch.tensor(cur_idxes, dtype=torch.long)
 joint_idxes.append(cur_idxes)
 
 
@@ -473,7 +473,7 @@ class Link_urdf:
 # calculate inerti
 def calculate_inertia(self, link_name_to_visited, link_name_to_link_struct):
 link_name_to_visited[self.name] = 1
-self.cur_inertia = torch.zeros((3, 3), dtype=torch.float32)
+self.cur_inertia = torch.zeros((3, 3), dtype=torch.float32)
 
 if self.joint is not None:
 for joint_nm in self.joint:
@@ -634,7 +634,7 @@ class Link_urdf:
 try:
 cur_child_inertia = cur_child_struct.cur_inertia
 except:
-cur_child_inertia = torch.eye(3, dtype=torch.float32)
+cur_child_inertia = torch.eye(3, dtype=torch.float32)
 
 
 if cur_joint.type in ['revolute']:
@@ -663,18 +663,18 @@ class Link_urdf:
 self.joint_angle = init_states[self.joint.joint_idx]
 joint_axis = self.joint.axis
 self.rot_vec = self.joint_angle * joint_axis
-self.joint.state = torch.tensor([1, 0, 0, 0], dtype=torch.float32)
+self.joint.state = torch.tensor([1, 0, 0, 0], dtype=torch.float32)
 self.joint.state = self.joint.state + update_quaternion(self.rot_vec, self.joint.state)
 self.joint.timestep_to_states[0] = self.joint.state.detach()
-self.joint.timestep_to_vels[0] = torch.zeros((3,), dtype=torch.float32).
+self.joint.timestep_to_vels[0] = torch.zeros((3,), dtype=torch.float32).detach() ## velocity ##
 for cur_link in self.children:
 cur_link.set_init_states_target_value(init_states)
 
 # should forward for one single step -> use the action #
 def set_init_states(self, ):
-self.joint.state = torch.tensor([1, 0, 0, 0], dtype=torch.float32)
+self.joint.state = torch.tensor([1, 0, 0, 0], dtype=torch.float32)
 self.joint.timestep_to_states[0] = self.joint.state.detach()
-self.joint.timestep_to_vels[0] = torch.zeros((3,), dtype=torch.float32).
+self.joint.timestep_to_vels[0] = torch.zeros((3,), dtype=torch.float32).detach() ## velocity ##
 for cur_link in self.children:
 cur_link.set_init_states()
 
@@ -737,31 +737,31 @@ class Joint_urdf: #
 
 #### only for the current state #### # joint urdf #
 self.state = nn.Parameter(
-torch.tensor([1., 0., 0., 0.], dtype=torch.float32, requires_grad=True)
+torch.tensor([1., 0., 0., 0.], dtype=torch.float32, requires_grad=True), requires_grad=True
 )
 self.action = nn.Parameter(
-torch.zeros((1,), dtype=torch.float32, requires_grad=True)
+torch.zeros((1,), dtype=torch.float32, requires_grad=True), requires_grad=True
 )
 # self.rot_mtx = np.eye(3, dtypes=np.float32)
 # self.trans_vec = np.zeros((3,), dtype=np.float32) ## rot m
-self.rot_mtx = nn.Parameter(torch.eye(n=3, dtype=torch.float32, requires_grad=True)
-self.trans_vec = nn.Parameter(torch.zeros((3,), dtype=torch.float32, requires_grad=True)
+self.rot_mtx = nn.Parameter(torch.eye(n=3, dtype=torch.float32, requires_grad=True), requires_grad=True)
+self.trans_vec = nn.Parameter(torch.zeros((3,), dtype=torch.float32, requires_grad=True), requires_grad=True)
 
 def set_initial_state(self, state):
 # joint angle as the state value #
-self.timestep_to_vels[0] = torch.zeros((3,), dtype=torch.float32).
+self.timestep_to_vels[0] = torch.zeros((3,), dtype=torch.float32).detach() ## velocity ##
 delta_rot_vec = self.axis_xyz * state
 # self.timestep_to_states[0] = state.detach()
-cur_state = torch.tensor([1., 0., 0., 0.], dtype=torch.float32)
+cur_state = torch.tensor([1., 0., 0., 0.], dtype=torch.float32)
 init_state = cur_state + update_quaternion(delta_rot_vec, cur_state)
 self.timestep_to_states[0] = init_state.detach()
 self.state = init_state
 
 def set_delta_state_and_update(self, state, cur_timestep):
-self.timestep_to_vels[cur_timestep] = torch.zeros((3,), dtype=torch.float32).
+self.timestep_to_vels[cur_timestep] = torch.zeros((3,), dtype=torch.float32).detach()
 delta_rot_vec = self.axis_xyz * state
 if cur_timestep == 0:
-prev_state = torch.tensor([1., 0., 0., 0.], dtype=torch.float32)
+prev_state = torch.tensor([1., 0., 0., 0.], dtype=torch.float32)
 else:
 # prev_state = self.timestep_to_states[cur_timestep - 1].detach()
 prev_state = self.timestep_to_states[cur_timestep - 1] # .detach() # not detach? #
@@ -771,7 +771,7 @@ class Joint_urdf: #
 
 
 def set_delta_state_and_update_v2(self, delta_state, cur_timestep):
-self.timestep_to_vels[cur_timestep] = torch.zeros((3,), dtype=torch.float32).
+self.timestep_to_vels[cur_timestep] = torch.zeros((3,), dtype=torch.float32).detach()
 
 if cur_timestep == 0:
 cur_state = delta_state
@@ -787,12 +787,12 @@ class Joint_urdf: #
 
 cur_rot_vec = self.axis_xyz * cur_state ### cur_state #### #
 # angle to the quaternion ? #
-init_state = torch.tensor([1., 0., 0., 0.], dtype=torch.float32)
+init_state = torch.tensor([1., 0., 0., 0.], dtype=torch.float32)
 cur_quat_state = init_state + update_quaternion(cur_rot_vec, init_state)
 self.state = cur_quat_state
 
 # if cur_timestep == 0:
-# prev_state = torch.tensor([1., 0., 0., 0.], dtype=torch.float32)
+# prev_state = torch.tensor([1., 0., 0., 0.], dtype=torch.float32)
 # else:
 # # prev_state = self.timestep_to_states[cur_timestep - 1].detach()
 # prev_state = self.timestep_to_states[cur_timestep - 1] # .detach() # not detach? #
@@ -816,8 +816,8 @@ class Joint_urdf: #
 self.rot_mtx = rot_mtx
 self.trans_vec = trans_vec
 elif self.type == "fixed":
-rot_mtx = torch.eye(3, dtype=torch.float32)
-trans_vec = torch.zeros((3,), dtype=torch.float32)
+rot_mtx = torch.eye(3, dtype=torch.float32)
+trans_vec = torch.zeros((3,), dtype=torch.float32)
 # trans_vec = self.origin_xyz
 self.rot_mtx = rot_mtx
 self.trans_vec = trans_vec #
@@ -840,7 +840,7 @@ class Joint_urdf: #
 torque = self.action * self.axis_xyz
 
 # # Compute inertia matrix #
-# inertial = torch.zeros((3, 3), dtype=torch.float32)
+# inertial = torch.zeros((3, 3), dtype=torch.float32)
 # for i_pts in range(self.visual_pts.size(0)):
 # cur_pts = self.visual_pts[i_pts]
 # cur_pts_mass = self.visual_pts_mass[i_pts]
@@ -848,7 +848,7 @@ class Joint_urdf: #
 # # cur_vert = init_passive_mesh[i_v]
 # # cur_r = cur_vert - init_passive_mesh_center
 # dot_r_r = torch.sum(cur_r * cur_r)
-# cur_eye_mtx = torch.eye(3, dtype=torch.float32)
+# cur_eye_mtx = torch.eye(3, dtype=torch.float32)
 # r_mult_rT = torch.matmul(cur_r.unsqueeze(-1), cur_r.unsqueeze(0))
 # inertial += (dot_r_r * cur_eye_mtx - r_mult_rT) * cur_pts_mass
 # m = torch.sum(self.visual_pts_mass)
@@ -935,7 +935,7 @@ class Joint_urdf: #
 torque = self.action * self.axis_xyz
 
 # # Compute inertia matrix #
-# inertial = torch.zeros((3, 3), dtype=torch.float32)
+# inertial = torch.zeros((3, 3), dtype=torch.float32)
 # for i_pts in range(self.visual_pts.size(0)):
 # cur_pts = self.visual_pts[i_pts]
 # cur_pts_mass = self.visual_pts_mass[i_pts]
@@ -943,7 +943,7 @@ class Joint_urdf: #
 # # cur_vert = init_passive_mesh[i_v]
 # # cur_r = cur_vert - init_passive_mesh_center
 # dot_r_r = torch.sum(cur_r * cur_r)
-# cur_eye_mtx = torch.eye(3, dtype=torch.float32)
+# cur_eye_mtx = torch.eye(3, dtype=torch.float32)
 # r_mult_rT = torch.matmul(cur_r.unsqueeze(-1), cur_r.unsqueeze(0))
 # inertial += (dot_r_r * cur_eye_mtx - r_mult_rT) * cur_pts_mass
 # m = torch.sum(self.visual_pts_mass)
@@ -957,7 +957,7 @@ class Joint_urdf: #
 
 # inertia_inv = torch.linalg.inv(cur_inertia).detach()
 
-inertia_inv = torch.eye(n=3, dtype=torch.float32)
+inertia_inv = torch.eye(n=3, dtype=torch.float32)
 
 
 
@@ -994,14 +994,14 @@ class Joint_urdf: #
 
 # if cur_timestep
 if cur_timestep == 0:
-self.timestep_to_states[cur_timestep] = torch.zeros((1,), dtype=torch.float32)
+self.timestep_to_states[cur_timestep] = torch.zeros((1,), dtype=torch.float32)
 cur_state = self.timestep_to_states[cur_timestep].detach()
 nex_state = cur_state + delta_state
 # nex_state = nex_state + penetration_delta_state
 ## state rot vector along axis ## ## get the pentrated froces -- calulaterot qj
 state_rot_vec_along_axis = nex_state * self.axis_xyz
 ### state in the rotation vector -> state in quaternion ###
-state_rot_quat = torch.tensor([1., 0., 0., 0.], dtype=torch.float32)
+state_rot_quat = torch.tensor([1., 0., 0., 0.], dtype=torch.float32) + update_quaternion(state_rot_vec_along_axis, torch.tensor([1., 0., 0., 0.], dtype=torch.float32))
 ### state
 self.state = state_rot_quat
 ### get states? ##
@@ -1060,7 +1060,7 @@ class Joint_urdf: #
 # penetration_delta_state = forces_torques_dot_axis
 else:
 penetration_delta_state = 0.0
-cur_joint_maximal_forces = torch.zeros((6,), dtype=torch.float32)
+cur_joint_maximal_forces = torch.zeros((6,), dtype=torch.float32)
 cur_joint_idx = joint_idx
 joint_penetration_forces[cur_joint_idx][:] = cur_joint_maximal_forces[:].clone()
 
@@ -1095,7 +1095,7 @@ class Robot_urdf:
 # #
 self.act_joint_idxes = list(self.actions_joint_name_to_joint_idx.values())
 self.act_joint_idxes = sorted(self.act_joint_idxes, reverse=False)
-self.act_joint_idxes = torch.tensor(self.act_joint_idxes, dtype=torch.long)
+self.act_joint_idxes = torch.tensor(self.act_joint_idxes, dtype=torch.long)[2:]
 
 self.real_actions_joint_name_to_joint_idx = real_actions_joint_name_to_joint_idx
 
@@ -1136,8 +1136,8 @@ class Robot_urdf:
 init_visual_meshes = {
 'vertices': [], 'faces': [], 'link_idxes': [], 'transformed_joint_pos': [], 'link_idxes': [], 'transformed_joint_pos': [], 'joint_link_idxes': []
 }
-init_parent_rot = torch.eye(3, dtype=torch.float32)
-init_parent_trans = torch.zeros((3,), dtype=torch.float32)
+init_parent_rot = torch.eye(3, dtype=torch.float32)
+init_parent_trans = torch.zeros((3,), dtype=torch.float32)
 
 palm_idx = self.link_name_to_link_idxes["palm"]
 palm_link = self.links[palm_idx]
@@ -1174,8 +1174,8 @@ class Robot_urdf:
 action_joint_name_to_joint_idx = self.real_actions_joint_name_to_joint_idx
 # print(f"action_joint_name_to_joint_idx: {action_joint_name_to_joint_idx}")
 
-parent_rot = torch.eye(3, dtype=torch.float32)
-parent_trans = torch.zeros((3,), dtype=torch.float32)
+parent_rot = torch.eye(3, dtype=torch.float32)
+parent_trans = torch.zeros((3,), dtype=torch.float32)
 
 # cur_timestep, time_cons, action_joint_name_to_joint_idx, link_name_to_visited, link_name_to_link_struct, parent_rot, parent_trans, penetration_forces, sampled_visual_pts_joint_idxes, joint_penetration_forces):
 
@@ -1251,8 +1251,8 @@ class Robot_urdf:
 
 link_name_to_visited = {}
 
-# parent_rot = torch.eye(3, dtype=torch.float32)
-# parent_trans = torch.zeros((3,), dtype=torch.float32)
+# parent_rot = torch.eye(3, dtype=torch.float32)
+# parent_trans = torch.zeros((3,), dtype=torch.float32)
 ## set actions ## #
 # set_actions_and_update_states_v2(self, action, cur_timestep, time_cons, cur_inertia):
 # self, actions, cur_timestep, time_cons, action_joint_name_to_joint_idx, link_name_to_visited, link_name_to_link_struct ## set and update states ##
@@ -1271,8 +1271,8 @@ class Robot_urdf:
 
 link_name_to_visited = {}
 
-parent_rot = torch.eye(3, dtype=torch.float32)
-parent_trans = torch.zeros((3,), dtype=torch.float32)
+parent_rot = torch.eye(3, dtype=torch.float32)
+parent_trans = torch.zeros((3,), dtype=torch.float32)
 ## set actions ## #
 # set_actions_and_update_states_v2(self, action, cur_timestep, time_cons, cur_inertia):
 # self, actions, cur_timestep, time_cons, action_joint_name_to_joint_idx, link_name_to_visited, link_name_to_link_struct ## set and update states ##
@@ -1363,7 +1363,7 @@ def parse_nparray_from_string(strr, args=None):
 vals = np.array(vals, dtype=np.float32)
 vals = torch.from_numpy(vals).float()
 ## vals ##
-vals = nn.Parameter(vals
+vals = nn.Parameter(vals, requires_grad=True)
 
 return vals
 
@@ -1486,7 +1486,7 @@ def parse_link_data_urdf(link):
 inertial_izz = inertial_inertia.attrib["izz"]
 inertial_izz = float(inertial_izz)
 
-inertial_inertia_mtx = torch.zeros((3, 3), dtype=torch.float32)
+inertial_inertia_mtx = torch.zeros((3, 3), dtype=torch.float32)
 inertial_inertia_mtx[0, 0] = inertial_ixx
 inertial_inertia_mtx[0, 1] = inertial_ixy
 inertial_inertia_mtx[0, 2] = inertial_ixz
@@ -1547,7 +1547,7 @@ def parse_joint_data_urdf(joint):
 origin_xyz_string = joint_origin.attrib["xyz"]
 origin_xyz = parse_nparray_from_string(origin_xyz_string)
 except:
-origin_xyz = torch.tensor([0., 0., 0.], dtype=torch.float32)
+origin_xyz = torch.tensor([0., 0., 0.], dtype=torch.float32)
 origin_xyz_string = ""
 
 joint_axis = joint.find("./axis")
@@ -1555,7 +1555,7 @@ def parse_joint_data_urdf(joint):
 joint_axis = joint_axis.attrib["xyz"]
 joint_axis = parse_nparray_from_string(joint_axis)
 else:
-joint_axis = torch.tensor([1, 0., 0.], dtype=torch.float32)
+joint_axis = torch.tensor([1, 0., 0.], dtype=torch.float32)
 
 joint_limit = joint.find("./limit")
 if joint_limit is not None:
@@ -1746,13 +1746,13 @@ class RobotAgent: # robot and the robot #
 
 self.time_constant = nn.Embedding(
 num_embeddings=3, embedding_dim=1
-)
+)
 torch.nn.init.ones_(self.time_constant.weight) #
 self.time_constant.weight.data = self.time_constant.weight.data * 0.2 ### time_constant data #
 
 self.optimizable_actions = nn.Embedding(
 num_embeddings=100, embedding_dim=60,
-)
+)
 torch.nn.init.zeros_(self.optimizable_actions.weight) #
 
 self.learning_rate = 5e-4
@@ -1770,24 +1770,24 @@ class RobotAgent: # robot and the robot #
 
 
 def set_init_states_target_value(self, init_states):
-# glb_rot = torch.eye(n=3, dtype=torch.float32)
-# glb_trans = torch.zeros((3,), dtype=torch.float32)
+# glb_rot = torch.eye(n=3, dtype=torch.float32)
+# glb_trans = torch.zeros((3,), dtype=torch.float32) ### glb_trans #### and the rot 3##
 
 # tot_init_states = {}
 # tot_init_states['glb_rot'] = glb_rot;
 # tot_init_states['glb_trans'] = glb_trans;
 # tot_init_states['links_init_states'] = init_states
 # self.active_robot.set_init_states_target_value(tot_init_states)
-# init_joint_states = torch.zeros((60, ), dtype=torch.float32)
+# init_joint_states = torch.zeros((60, ), dtype=torch.float32)
 self.active_robot.set_initial_state(init_states)
 
 def set_init_states(self):
-# glb_rot = torch.eye(n=3, dtype=torch.float32)
-# glb_trans = torch.zeros((3,), dtype=torch.float32)
+# glb_rot = torch.eye(n=3, dtype=torch.float32)
+# glb_trans = torch.zeros((3,), dtype=torch.float32) ### glb_trans #### and the rot 3##
 
 # ### random rotation ###
 # # glb_rot_np = R.random().as_matrix()
-# # glb_rot = torch.from_numpy(glb_rot_np).float()
+# # glb_rot = torch.from_numpy(glb_rot_np).float()
 # ### random rotation ###
 
 # # glb_rot, glb_trans #
@@ -1796,7 +1796,7 @@ class RobotAgent: # robot and the robot #
 # init_states['glb_trans'] = glb_trans;
 # self.active_robot.set_init_states(init_states)
 
-init_joint_states = torch.zeros((60, ), dtype=torch.float32)
+init_joint_states = torch.zeros((60, ), dtype=torch.float32)
 self.active_robot.set_initial_state(init_joint_states)
 
 # cur_verts, joint_idxes = get_init_state_visual_pts(expanded_pts=False, ret_joint_idxes=True)
@@ -1821,13 +1821,13 @@ class RobotAgent: # robot and the robot #
 
 def set_actions_and_update_states(self, actions, cur_timestep):
 #
-time_cons = self.time_constant(torch.zeros((1,), dtype=torch.long)
+time_cons = self.time_constant(torch.zeros((1,), dtype=torch.long)) ### time constant of the system ##
 self.active_robot.set_actions_and_update_states(actions, cur_timestep, time_cons) ###
 pass
 
 def set_actions_and_update_states_v2(self, actions, cur_timestep, penetration_forces=None, sampled_visual_pts_joint_idxes=None):
 #
-time_cons = self.time_constant(torch.zeros((1,), dtype=torch.long)
+time_cons = self.time_constant(torch.zeros((1,), dtype=torch.long)) ### time constant of the system ##
 self.active_robot.set_actions_and_update_states_v2(actions, cur_timestep, time_cons, penetration_forces=penetration_forces, sampled_visual_pts_joint_idxes=sampled_visual_pts_joint_idxes) ###
 pass
 
@@ -1841,9 +1841,9 @@ class RobotAgent: # robot and the robot #
 timestep_to_visual_pts = {}
 for i_step in range(50):
 actions = {}
-actions['delta_glb_rot'] = torch.eye(3, dtype=torch.float32)
-actions['delta_glb_trans'] = torch.zeros((3,), dtype=torch.float32)
-actions_link_actions = torch.ones((22, ), dtype=torch.float32)
+actions['delta_glb_rot'] = torch.eye(3, dtype=torch.float32)
+actions['delta_glb_trans'] = torch.zeros((3,), dtype=torch.float32)
+actions_link_actions = torch.ones((22, ), dtype=torch.float32)
 # actions_link_actions = actions_link_actions * 0.2
 actions_link_actions = actions_link_actions * -1. #
 actions['link_actions'] = actions_link_actions
@@ -1862,7 +1862,7 @@ class RobotAgent: # robot and the robot #
 # TODO: load reference points #
 self.ts_to_reference_pts = np.load(reference_pts_dict, allow_pickle=True).item() ####
 self.ts_to_reference_pts = {
-ts // 2 + 1: torch.from_numpy(self.ts_to_reference_pts[ts]).float()
+ts // 2 + 1: torch.from_numpy(self.ts_to_reference_pts[ts]).float() for ts in self.ts_to_reference_pts
 }
 
 
@@ -1879,9 +1879,9 @@ class RobotAgent: # robot and the robot #
 for cur_ts in range(self.n_timesteps):
 # print(f"iter: {i_iter}, cur_ts: {cur_ts}")
 # actions = {}
-# actions['delta_glb_rot'] = torch.eye(3, dtype=torch.float32)
-# actions['delta_glb_trans'] = torch.zeros((3,), dtype=torch.float32)
-actions_link_actions = self.optimizable_actions(torch.zeros((1,), dtype=torch.long)
+# actions['delta_glb_rot'] = torch.eye(3, dtype=torch.float32)
+# actions['delta_glb_trans'] = torch.zeros((3,), dtype=torch.float32)
+actions_link_actions = self.optimizable_actions(torch.zeros((1,), dtype=torch.long) + cur_ts).squeeze(0)
 # actions_link_actions = actions_link_actions * 0.2
 # actions_link_actions = actions_link_actions * -1. #
 # actions['link_actions'] = actions_link_actions
@@ -2240,39 +2240,39 @@ def get_shadow_GT_states_data_from_ckpt(ckpt_fn):
 # robot actions # #
 # robot_actions = nn.Embedding(
 # num_embeddings=num_steps, embedding_dim=22,
-# )
+# )
 # torch.nn.init.zeros_(robot_actions.weight)
 # # params_to_train += list(robot_actions.parameters())
 
 # robot_delta_states = nn.Embedding(
 # num_embeddings=num_steps, embedding_dim=60,
-# )
+# )
 # torch.nn.init.zeros_(robot_delta_states.weight)
 # # params_to_train += list(robot_delta_states.parameters())
 
 
 robot_states = nn.Embedding(
 num_embeddings=num_steps, embedding_dim=60,
-)
+)
 torch.nn.init.zeros_(robot_states.weight)
 # params_to_train += list(robot_states.parameters())
 
 # robot_init_states = nn.Embedding(
 # num_embeddings=1, embedding_dim=22,
-# )
+# )
 # torch.nn.init.zeros_(robot_init_states.weight)
 # # params_to_train += list(robot_init_states.parameters())
 
 robot_glb_rotation = nn.Embedding(
 num_embeddings=num_steps, embedding_dim=4
-)
+)
 robot_glb_rotation.weight.data[:, 0] = 1.
 robot_glb_rotation.weight.data[:, 1:] = 0.
 
 
 robot_glb_trans = nn.Embedding(
 num_embeddings=num_steps, embedding_dim=3
-)
+)
 torch.nn.init.zeros_(robot_glb_trans.weight)
 
 ''' Load optimized MANO hand actions and states '''
@@ -2531,7 +2531,7 @@ if __name__=='__main__': # # #
 tot_transformed_pts = []
 for i_ts in range(len(blended_states)):
 cur_blended_states = blended_states[i_ts]
-cur_blended_states = torch.from_numpy(cur_blended_states).float()
+cur_blended_states = torch.from_numpy(cur_blended_states).float()
 robot_agent.active_robot.set_delta_state_and_update_v2(cur_blended_states, 0)
 cur_pts = robot_agent.get_init_state_visual_pts().detach().cpu().numpy()
 
@@ -2568,7 +2568,7 @@ if __name__=='__main__': # # #
 shadow_hand_sv_fn = "/home/xueyi/diffsim/NeuS/raw_data/scaled_redmax_hand_rescaled_grab.obj"
 shadow_hand_mesh.export(shadow_hand_sv_fn)
 
-init_joint_states = torch.randn((60, ), dtype=torch.float32)
+init_joint_states = torch.randn((60, ), dtype=torch.float32)
 robot_agent.set_initial_state(init_joint_states)
 
 
@@ -2676,15 +2676,15 @@ if __name__=='__main__': # # #
 # mesh_obj.export(f"hand_urdf.ply")
 
 ##### Test the set initial state function #####
-init_joint_states = torch.zeros((60, ), dtype=torch.float32)
+init_joint_states = torch.zeros((60, ), dtype=torch.float32)
 cur_robot.set_initial_state(init_joint_states)
 ##### Test the set initial state function #####
 
 
 
 
-cur_zeros_actions = torch.zeros((60, ), dtype=torch.float32)
-cur_ones_actions = torch.ones((60, ), dtype=torch.float32)
+cur_zeros_actions = torch.zeros((60, ), dtype=torch.float32)
+cur_ones_actions = torch.ones((60, ), dtype=torch.float32) # * 100
 
 ts_to_mesh_verts = {}
 for i_ts in range(50):
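Most hunks in this file follow one pattern: each tensor constructor keeps its dtype but loses a trailing suffix that this render truncates on the old side. My assumption, consistent with the switch to CPU-only wheels in requirements.txt below, is that the dropped suffixes were GPU placement calls such as .cuda(). A minimal device-agnostic sketch of the same constructors under that assumption:

import torch

# Choose CUDA when present, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# CPU-safe forms of the constructors edited throughout this diff.
inertia = torch.eye(3, dtype=torch.float32, device=device)
trans_vec = torch.zeros((3,), dtype=torch.float32, device=device)
unit_quat = torch.tensor([1., 0., 0., 0.], dtype=torch.float32, device=device)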
models/fields.py
CHANGED
The diff for this file is too large to render. See raw diff.
pre-requirements.txt
CHANGED
@@ -1,2 +1,2 @@
 pip==23.3.2
-torch==2.2.0
+# torch==2.2.0
requirements.txt
CHANGED
@@ -1,5 +1,6 @@
 -f https://download.pytorch.org/whl/cpu/torch_stable.html
--f https://data.pyg.org/whl/torch-2.2.0%2Bcpu.html
+# -f https://data.pyg.org/whl/torch-2.2.0%2Bcpu.html
+-i https://download.pytorch.org/whl/cpu
 # pip==20.2.4
 torch==2.2.0
 # torchvision==0.13.1
scripts_demo/train_grab_pointset_points_dyn_s1.sh
CHANGED
@@ -18,11 +18,15 @@ export conf=dyn_grab_pointset_points_dyn_s1.conf
 export conf_root="./confs_new"
 
 
-export
-
+export data_fn="./data/102_grab_all_data.npy"
+
+
+# bash scripts_demo/train_grab_pointset_points_dyn_s1.sh
 
 export cuda_ids="0"
 
 
-CUDA_VISIBLE_DEVICES=${cuda_ids}
+# CUDA_VISIBLE_DEVICES=${cuda_ids}
+
+python ${trainer} --mode ${mode} --conf ${conf_root}/${conf} --case ${data_case} --data_fn=${data_fn}
 