meow commited on
Commit
6f7cc86
1 Parent(s): 710e818
app.py CHANGED
@@ -149,7 +149,7 @@ def create_demo():
149
  inputs = input_file
150
  outputs = output_file
151
  gr.Examples(
152
- examples=[os.path.join(os.path.dirname(__file__), "./gradio_inter/20231104_017.pkl")],
153
  inputs=inputs,
154
  fn=predict,
155
  outputs=outputs,
 
149
  inputs = input_file
150
  outputs = output_file
151
  gr.Examples(
152
+ examples=[os.path.join(os.path.dirname(__file__), "./data/102_grab_all_data.npy")],
153
  inputs=inputs,
154
  fn=predict,
155
  outputs=outputs,
exp_runner_stage_1.py CHANGED
The diff for this file is too large to render. See raw diff
 
models/dyn_model_act_v2.py CHANGED
@@ -147,7 +147,7 @@ class Inertial:
147
  self.mass = mass
148
  self.inertia = inertia
149
  if torch.sum(self.inertia).item() < 1e-4:
150
- self.inertia = self.inertia + torch.eye(3, dtype=torch.float32).cuda()
151
  pass
152
 
153
  class Visual:
@@ -189,8 +189,8 @@ class Visual:
189
  vertices = mesh.vertices
190
  faces = mesh.faces
191
 
192
- vertices = torch.from_numpy(vertices).float().cuda()
193
- faces =torch.from_numpy(faces).long().cuda()
194
 
195
  vertices = vertices * self.geometry_mesh_scale.unsqueeze(0) + self.visual_xyz.unsqueeze(0)
196
 
@@ -244,7 +244,7 @@ class Visual:
244
  # sample from the circile with cur_pts as thejcenter and the radius as expand_r
245
  # (-r, r) # sample the offset vector in the size of (nn_expand_pts, 3)
246
  offset_dist = Uniform(-1. * expand_r, expand_r)
247
- offset_vec = offset_dist.sample((nn_expand_pts, 3)).cuda()
248
  cur_expanded_pts = cur_pts + offset_vec
249
  cur_expanded_visual_pts.append(cur_expanded_pts)
250
  cur_expanded_visual_pts = torch.cat(cur_expanded_visual_pts, dim=0)
@@ -252,7 +252,7 @@ class Visual:
252
  else:
253
  print(f"Loading visual pts from {expand_save_fn}") # load from the fn #
254
  cur_expanded_visual_pts = np.load(expand_save_fn, allow_pickle=True)
255
- cur_expanded_visual_pts = torch.from_numpy(cur_expanded_visual_pts).float().cuda()
256
  self.cur_expanded_visual_pts = cur_expanded_visual_pts # expanded visual pts #
257
  return self.cur_expanded_visual_pts
258
 
@@ -343,7 +343,7 @@ class Link_urdf:
343
  try:
344
  cur_child_inertia = cur_child_struct.cur_inertia
345
  except:
346
- cur_child_inertia = torch.eye(3, dtype=torch.float32).cuda()
347
 
348
 
349
  if cur_joint.type in ['revolute'] and (cur_joint_name not in ['WRJ2', 'WRJ1']):
@@ -428,10 +428,10 @@ class Link_urdf:
428
  joint_link_idxes = torch.cat(joint_link_idxes, dim=-1) ### joint_link idxes ###
429
  cur_joint_idx = cur_child.link_idx
430
  joint_link_idxes = torch.cat(
431
- [joint_link_idxes, torch.tensor([cur_joint_idx], dtype=torch.long).cuda()], dim=-1
432
  )
433
  else:
434
- joint_link_idxes = torch.tensor([cur_child.link_idx], dtype=torch.long).cuda().view(1,)
435
 
436
  # joint link idxes #
437
 
@@ -446,7 +446,7 @@ class Link_urdf:
446
 
447
  # joint_origin_xyz = self.joint.origin_xyz # c ## get forces from the expanded point set ##
448
  else:
449
- joint_origin_xyz = torch.tensor([0., 0., 0.], dtype=torch.float32).cuda()
450
  # self.parent_rot_mtx = parent_rot
451
  # self.parent_trans_vec = parent_trans + joint_origin_xyz
452
 
@@ -455,13 +455,13 @@ class Link_urdf:
455
  # ## get init visual meshes ## ## --
456
  init_visual_meshes = self.visual.get_init_visual_meshes(parent_rot, parent_trans, init_visual_meshes, expanded_pts=expanded_pts)
457
  cur_visual_mesh_pts_nn = self.visual.vertices.size(0)
458
- cur_link_idxes = torch.zeros((cur_visual_mesh_pts_nn, ), dtype=torch.long).cuda()+ self.link_idx
459
  init_visual_meshes['link_idxes'].append(cur_link_idxes)
460
 
461
  # self.link_idx #
462
  if joint_idxes is not None:
463
  cur_idxes = [self.link_idx for _ in range(cur_visual_mesh_pts_nn)]
464
- cur_idxes = torch.tensor(cur_idxes, dtype=torch.long).cuda()
465
  joint_idxes.append(cur_idxes)
466
 
467
 
@@ -473,7 +473,7 @@ class Link_urdf:
473
  # calculate inerti
474
  def calculate_inertia(self, link_name_to_visited, link_name_to_link_struct):
475
  link_name_to_visited[self.name] = 1
476
- self.cur_inertia = torch.zeros((3, 3), dtype=torch.float32).cuda()
477
 
478
  if self.joint is not None:
479
  for joint_nm in self.joint:
@@ -634,7 +634,7 @@ class Link_urdf:
634
  try:
635
  cur_child_inertia = cur_child_struct.cur_inertia
636
  except:
637
- cur_child_inertia = torch.eye(3, dtype=torch.float32).cuda()
638
 
639
 
640
  if cur_joint.type in ['revolute']:
@@ -663,18 +663,18 @@ class Link_urdf:
663
  self.joint_angle = init_states[self.joint.joint_idx]
664
  joint_axis = self.joint.axis
665
  self.rot_vec = self.joint_angle * joint_axis
666
- self.joint.state = torch.tensor([1, 0, 0, 0], dtype=torch.float32).cuda()
667
  self.joint.state = self.joint.state + update_quaternion(self.rot_vec, self.joint.state)
668
  self.joint.timestep_to_states[0] = self.joint.state.detach()
669
- self.joint.timestep_to_vels[0] = torch.zeros((3,), dtype=torch.float32).cuda().detach() ## velocity ##
670
  for cur_link in self.children:
671
  cur_link.set_init_states_target_value(init_states)
672
 
673
  # should forward for one single step -> use the action #
674
  def set_init_states(self, ):
675
- self.joint.state = torch.tensor([1, 0, 0, 0], dtype=torch.float32).cuda()
676
  self.joint.timestep_to_states[0] = self.joint.state.detach()
677
- self.joint.timestep_to_vels[0] = torch.zeros((3,), dtype=torch.float32).cuda().detach() ## velocity ##
678
  for cur_link in self.children:
679
  cur_link.set_init_states()
680
 
@@ -737,31 +737,31 @@ class Joint_urdf: #
737
 
738
  #### only for the current state #### # joint urdf #
739
  self.state = nn.Parameter(
740
- torch.tensor([1., 0., 0., 0.], dtype=torch.float32, requires_grad=True).cuda(), requires_grad=True
741
  )
742
  self.action = nn.Parameter(
743
- torch.zeros((1,), dtype=torch.float32, requires_grad=True).cuda(), requires_grad=True
744
  )
745
  # self.rot_mtx = np.eye(3, dtypes=np.float32)
746
  # self.trans_vec = np.zeros((3,), dtype=np.float32) ## rot m
747
- self.rot_mtx = nn.Parameter(torch.eye(n=3, dtype=torch.float32, requires_grad=True).cuda(), requires_grad=True)
748
- self.trans_vec = nn.Parameter(torch.zeros((3,), dtype=torch.float32, requires_grad=True).cuda(), requires_grad=True)
749
 
750
  def set_initial_state(self, state):
751
  # joint angle as the state value #
752
- self.timestep_to_vels[0] = torch.zeros((3,), dtype=torch.float32).cuda().detach() ## velocity ##
753
  delta_rot_vec = self.axis_xyz * state
754
  # self.timestep_to_states[0] = state.detach()
755
- cur_state = torch.tensor([1., 0., 0., 0.], dtype=torch.float32).cuda()
756
  init_state = cur_state + update_quaternion(delta_rot_vec, cur_state)
757
  self.timestep_to_states[0] = init_state.detach()
758
  self.state = init_state
759
 
760
  def set_delta_state_and_update(self, state, cur_timestep):
761
- self.timestep_to_vels[cur_timestep] = torch.zeros((3,), dtype=torch.float32).cuda().detach()
762
  delta_rot_vec = self.axis_xyz * state
763
  if cur_timestep == 0:
764
- prev_state = torch.tensor([1., 0., 0., 0.], dtype=torch.float32).cuda()
765
  else:
766
  # prev_state = self.timestep_to_states[cur_timestep - 1].detach()
767
  prev_state = self.timestep_to_states[cur_timestep - 1] # .detach() # not detach? #
@@ -771,7 +771,7 @@ class Joint_urdf: #
771
 
772
 
773
  def set_delta_state_and_update_v2(self, delta_state, cur_timestep):
774
- self.timestep_to_vels[cur_timestep] = torch.zeros((3,), dtype=torch.float32).cuda().detach()
775
 
776
  if cur_timestep == 0:
777
  cur_state = delta_state
@@ -787,12 +787,12 @@ class Joint_urdf: #
787
 
788
  cur_rot_vec = self.axis_xyz * cur_state ### cur_state #### #
789
  # angle to the quaternion ? #
790
- init_state = torch.tensor([1., 0., 0., 0.], dtype=torch.float32).cuda()
791
  cur_quat_state = init_state + update_quaternion(cur_rot_vec, init_state)
792
  self.state = cur_quat_state
793
 
794
  # if cur_timestep == 0:
795
- # prev_state = torch.tensor([1., 0., 0., 0.], dtype=torch.float32).cuda()
796
  # else:
797
  # # prev_state = self.timestep_to_states[cur_timestep - 1].detach()
798
  # prev_state = self.timestep_to_states[cur_timestep - 1] # .detach() # not detach? #
@@ -816,8 +816,8 @@ class Joint_urdf: #
816
  self.rot_mtx = rot_mtx
817
  self.trans_vec = trans_vec
818
  elif self.type == "fixed":
819
- rot_mtx = torch.eye(3, dtype=torch.float32).cuda()
820
- trans_vec = torch.zeros((3,), dtype=torch.float32).cuda()
821
  # trans_vec = self.origin_xyz
822
  self.rot_mtx = rot_mtx
823
  self.trans_vec = trans_vec #
@@ -840,7 +840,7 @@ class Joint_urdf: #
840
  torque = self.action * self.axis_xyz
841
 
842
  # # Compute inertia matrix #
843
- # inertial = torch.zeros((3, 3), dtype=torch.float32).cuda()
844
  # for i_pts in range(self.visual_pts.size(0)):
845
  # cur_pts = self.visual_pts[i_pts]
846
  # cur_pts_mass = self.visual_pts_mass[i_pts]
@@ -848,7 +848,7 @@ class Joint_urdf: #
848
  # # cur_vert = init_passive_mesh[i_v]
849
  # # cur_r = cur_vert - init_passive_mesh_center
850
  # dot_r_r = torch.sum(cur_r * cur_r)
851
- # cur_eye_mtx = torch.eye(3, dtype=torch.float32).cuda()
852
  # r_mult_rT = torch.matmul(cur_r.unsqueeze(-1), cur_r.unsqueeze(0))
853
  # inertial += (dot_r_r * cur_eye_mtx - r_mult_rT) * cur_pts_mass
854
  # m = torch.sum(self.visual_pts_mass)
@@ -935,7 +935,7 @@ class Joint_urdf: #
935
  torque = self.action * self.axis_xyz
936
 
937
  # # Compute inertia matrix #
938
- # inertial = torch.zeros((3, 3), dtype=torch.float32).cuda()
939
  # for i_pts in range(self.visual_pts.size(0)):
940
  # cur_pts = self.visual_pts[i_pts]
941
  # cur_pts_mass = self.visual_pts_mass[i_pts]
@@ -943,7 +943,7 @@ class Joint_urdf: #
943
  # # cur_vert = init_passive_mesh[i_v]
944
  # # cur_r = cur_vert - init_passive_mesh_center
945
  # dot_r_r = torch.sum(cur_r * cur_r)
946
- # cur_eye_mtx = torch.eye(3, dtype=torch.float32).cuda()
947
  # r_mult_rT = torch.matmul(cur_r.unsqueeze(-1), cur_r.unsqueeze(0))
948
  # inertial += (dot_r_r * cur_eye_mtx - r_mult_rT) * cur_pts_mass
949
  # m = torch.sum(self.visual_pts_mass)
@@ -957,7 +957,7 @@ class Joint_urdf: #
957
 
958
  # inertia_inv = torch.linalg.inv(cur_inertia).detach()
959
 
960
- inertia_inv = torch.eye(n=3, dtype=torch.float32).cuda()
961
 
962
 
963
 
@@ -994,14 +994,14 @@ class Joint_urdf: #
994
 
995
  # if cur_timestep
996
  if cur_timestep == 0:
997
- self.timestep_to_states[cur_timestep] = torch.zeros((1,), dtype=torch.float32).cuda()
998
  cur_state = self.timestep_to_states[cur_timestep].detach()
999
  nex_state = cur_state + delta_state
1000
  # nex_state = nex_state + penetration_delta_state
1001
  ## state rot vector along axis ## ## get the pentrated froces -- calulaterot qj
1002
  state_rot_vec_along_axis = nex_state * self.axis_xyz
1003
  ### state in the rotation vector -> state in quaternion ###
1004
- state_rot_quat = torch.tensor([1., 0., 0., 0.], dtype=torch.float32).cuda() + update_quaternion(state_rot_vec_along_axis, torch.tensor([1., 0., 0., 0.], dtype=torch.float32).cuda())
1005
  ### state
1006
  self.state = state_rot_quat
1007
  ### get states? ##
@@ -1060,7 +1060,7 @@ class Joint_urdf: #
1060
  # penetration_delta_state = forces_torques_dot_axis
1061
  else:
1062
  penetration_delta_state = 0.0
1063
- cur_joint_maximal_forces = torch.zeros((6,), dtype=torch.float32).cuda()
1064
  cur_joint_idx = joint_idx
1065
  joint_penetration_forces[cur_joint_idx][:] = cur_joint_maximal_forces[:].clone()
1066
 
@@ -1095,7 +1095,7 @@ class Robot_urdf:
1095
  # #
1096
  self.act_joint_idxes = list(self.actions_joint_name_to_joint_idx.values())
1097
  self.act_joint_idxes = sorted(self.act_joint_idxes, reverse=False)
1098
- self.act_joint_idxes = torch.tensor(self.act_joint_idxes, dtype=torch.long).cuda()[2:]
1099
 
1100
  self.real_actions_joint_name_to_joint_idx = real_actions_joint_name_to_joint_idx
1101
 
@@ -1136,8 +1136,8 @@ class Robot_urdf:
1136
  init_visual_meshes = {
1137
  'vertices': [], 'faces': [], 'link_idxes': [], 'transformed_joint_pos': [], 'link_idxes': [], 'transformed_joint_pos': [], 'joint_link_idxes': []
1138
  }
1139
- init_parent_rot = torch.eye(3, dtype=torch.float32).cuda()
1140
- init_parent_trans = torch.zeros((3,), dtype=torch.float32).cuda()
1141
 
1142
  palm_idx = self.link_name_to_link_idxes["palm"]
1143
  palm_link = self.links[palm_idx]
@@ -1174,8 +1174,8 @@ class Robot_urdf:
1174
  action_joint_name_to_joint_idx = self.real_actions_joint_name_to_joint_idx
1175
  # print(f"action_joint_name_to_joint_idx: {action_joint_name_to_joint_idx}")
1176
 
1177
- parent_rot = torch.eye(3, dtype=torch.float32).cuda()
1178
- parent_trans = torch.zeros((3,), dtype=torch.float32).cuda()
1179
 
1180
  # cur_timestep, time_cons, action_joint_name_to_joint_idx, link_name_to_visited, link_name_to_link_struct, parent_rot, parent_trans, penetration_forces, sampled_visual_pts_joint_idxes, joint_penetration_forces):
1181
 
@@ -1251,8 +1251,8 @@ class Robot_urdf:
1251
 
1252
  link_name_to_visited = {}
1253
 
1254
- # parent_rot = torch.eye(3, dtype=torch.float32).cuda()
1255
- # parent_trans = torch.zeros((3,), dtype=torch.float32).cuda()
1256
  ## set actions ## #
1257
  # set_actions_and_update_states_v2(self, action, cur_timestep, time_cons, cur_inertia):
1258
  # self, actions, cur_timestep, time_cons, action_joint_name_to_joint_idx, link_name_to_visited, link_name_to_link_struct ## set and update states ##
@@ -1271,8 +1271,8 @@ class Robot_urdf:
1271
 
1272
  link_name_to_visited = {}
1273
 
1274
- parent_rot = torch.eye(3, dtype=torch.float32).cuda()
1275
- parent_trans = torch.zeros((3,), dtype=torch.float32).cuda()
1276
  ## set actions ## #
1277
  # set_actions_and_update_states_v2(self, action, cur_timestep, time_cons, cur_inertia):
1278
  # self, actions, cur_timestep, time_cons, action_joint_name_to_joint_idx, link_name_to_visited, link_name_to_link_struct ## set and update states ##
@@ -1363,7 +1363,7 @@ def parse_nparray_from_string(strr, args=None):
1363
  vals = np.array(vals, dtype=np.float32)
1364
  vals = torch.from_numpy(vals).float()
1365
  ## vals ##
1366
- vals = nn.Parameter(vals.cuda(), requires_grad=True)
1367
 
1368
  return vals
1369
 
@@ -1486,7 +1486,7 @@ def parse_link_data_urdf(link):
1486
  inertial_izz = inertial_inertia.attrib["izz"]
1487
  inertial_izz = float(inertial_izz)
1488
 
1489
- inertial_inertia_mtx = torch.zeros((3, 3), dtype=torch.float32).cuda()
1490
  inertial_inertia_mtx[0, 0] = inertial_ixx
1491
  inertial_inertia_mtx[0, 1] = inertial_ixy
1492
  inertial_inertia_mtx[0, 2] = inertial_ixz
@@ -1547,7 +1547,7 @@ def parse_joint_data_urdf(joint):
1547
  origin_xyz_string = joint_origin.attrib["xyz"]
1548
  origin_xyz = parse_nparray_from_string(origin_xyz_string)
1549
  except:
1550
- origin_xyz = torch.tensor([0., 0., 0.], dtype=torch.float32).cuda()
1551
  origin_xyz_string = ""
1552
 
1553
  joint_axis = joint.find("./axis")
@@ -1555,7 +1555,7 @@ def parse_joint_data_urdf(joint):
1555
  joint_axis = joint_axis.attrib["xyz"]
1556
  joint_axis = parse_nparray_from_string(joint_axis)
1557
  else:
1558
- joint_axis = torch.tensor([1, 0., 0.], dtype=torch.float32).cuda()
1559
 
1560
  joint_limit = joint.find("./limit")
1561
  if joint_limit is not None:
@@ -1746,13 +1746,13 @@ class RobotAgent: # robot and the robot #
1746
 
1747
  self.time_constant = nn.Embedding(
1748
  num_embeddings=3, embedding_dim=1
1749
- ).cuda()
1750
  torch.nn.init.ones_(self.time_constant.weight) #
1751
  self.time_constant.weight.data = self.time_constant.weight.data * 0.2 ### time_constant data #
1752
 
1753
  self.optimizable_actions = nn.Embedding(
1754
  num_embeddings=100, embedding_dim=60,
1755
- ).cuda()
1756
  torch.nn.init.zeros_(self.optimizable_actions.weight) #
1757
 
1758
  self.learning_rate = 5e-4
@@ -1770,24 +1770,24 @@ class RobotAgent: # robot and the robot #
1770
 
1771
 
1772
  def set_init_states_target_value(self, init_states):
1773
- # glb_rot = torch.eye(n=3, dtype=torch.float32).cuda()
1774
- # glb_trans = torch.zeros((3,), dtype=torch.float32).cuda() ### glb_trans #### and the rot 3##
1775
 
1776
  # tot_init_states = {}
1777
  # tot_init_states['glb_rot'] = glb_rot;
1778
  # tot_init_states['glb_trans'] = glb_trans;
1779
  # tot_init_states['links_init_states'] = init_states
1780
  # self.active_robot.set_init_states_target_value(tot_init_states)
1781
- # init_joint_states = torch.zeros((60, ), dtype=torch.float32).cuda()
1782
  self.active_robot.set_initial_state(init_states)
1783
 
1784
  def set_init_states(self):
1785
- # glb_rot = torch.eye(n=3, dtype=torch.float32).cuda()
1786
- # glb_trans = torch.zeros((3,), dtype=torch.float32).cuda() ### glb_trans #### and the rot 3##
1787
 
1788
  # ### random rotation ###
1789
  # # glb_rot_np = R.random().as_matrix()
1790
- # # glb_rot = torch.from_numpy(glb_rot_np).float().cuda()
1791
  # ### random rotation ###
1792
 
1793
  # # glb_rot, glb_trans #
@@ -1796,7 +1796,7 @@ class RobotAgent: # robot and the robot #
1796
  # init_states['glb_trans'] = glb_trans;
1797
  # self.active_robot.set_init_states(init_states)
1798
 
1799
- init_joint_states = torch.zeros((60, ), dtype=torch.float32).cuda()
1800
  self.active_robot.set_initial_state(init_joint_states)
1801
 
1802
  # cur_verts, joint_idxes = get_init_state_visual_pts(expanded_pts=False, ret_joint_idxes=True)
@@ -1821,13 +1821,13 @@ class RobotAgent: # robot and the robot #
1821
 
1822
  def set_actions_and_update_states(self, actions, cur_timestep):
1823
  #
1824
- time_cons = self.time_constant(torch.zeros((1,), dtype=torch.long).cuda()) ### time constant of the system ##
1825
  self.active_robot.set_actions_and_update_states(actions, cur_timestep, time_cons) ###
1826
  pass
1827
 
1828
  def set_actions_and_update_states_v2(self, actions, cur_timestep, penetration_forces=None, sampled_visual_pts_joint_idxes=None):
1829
  #
1830
- time_cons = self.time_constant(torch.zeros((1,), dtype=torch.long).cuda()) ### time constant of the system ##
1831
  self.active_robot.set_actions_and_update_states_v2(actions, cur_timestep, time_cons, penetration_forces=penetration_forces, sampled_visual_pts_joint_idxes=sampled_visual_pts_joint_idxes) ###
1832
  pass
1833
 
@@ -1841,9 +1841,9 @@ class RobotAgent: # robot and the robot #
1841
  timestep_to_visual_pts = {}
1842
  for i_step in range(50):
1843
  actions = {}
1844
- actions['delta_glb_rot'] = torch.eye(3, dtype=torch.float32).cuda()
1845
- actions['delta_glb_trans'] = torch.zeros((3,), dtype=torch.float32).cuda()
1846
- actions_link_actions = torch.ones((22, ), dtype=torch.float32).cuda()
1847
  # actions_link_actions = actions_link_actions * 0.2
1848
  actions_link_actions = actions_link_actions * -1. #
1849
  actions['link_actions'] = actions_link_actions
@@ -1862,7 +1862,7 @@ class RobotAgent: # robot and the robot #
1862
  # TODO: load reference points #
1863
  self.ts_to_reference_pts = np.load(reference_pts_dict, allow_pickle=True).item() ####
1864
  self.ts_to_reference_pts = {
1865
- ts // 2 + 1: torch.from_numpy(self.ts_to_reference_pts[ts]).float().cuda() for ts in self.ts_to_reference_pts
1866
  }
1867
 
1868
 
@@ -1879,9 +1879,9 @@ class RobotAgent: # robot and the robot #
1879
  for cur_ts in range(self.n_timesteps):
1880
  # print(f"iter: {i_iter}, cur_ts: {cur_ts}")
1881
  # actions = {}
1882
- # actions['delta_glb_rot'] = torch.eye(3, dtype=torch.float32).cuda()
1883
- # actions['delta_glb_trans'] = torch.zeros((3,), dtype=torch.float32).cuda()
1884
- actions_link_actions = self.optimizable_actions(torch.zeros((1,), dtype=torch.long).cuda() + cur_ts).squeeze(0)
1885
  # actions_link_actions = actions_link_actions * 0.2
1886
  # actions_link_actions = actions_link_actions * -1. #
1887
  # actions['link_actions'] = actions_link_actions
@@ -2240,39 +2240,39 @@ def get_shadow_GT_states_data_from_ckpt(ckpt_fn):
2240
  # robot actions # #
2241
  # robot_actions = nn.Embedding(
2242
  # num_embeddings=num_steps, embedding_dim=22,
2243
- # ).cuda()
2244
  # torch.nn.init.zeros_(robot_actions.weight)
2245
  # # params_to_train += list(robot_actions.parameters())
2246
 
2247
  # robot_delta_states = nn.Embedding(
2248
  # num_embeddings=num_steps, embedding_dim=60,
2249
- # ).cuda()
2250
  # torch.nn.init.zeros_(robot_delta_states.weight)
2251
  # # params_to_train += list(robot_delta_states.parameters())
2252
 
2253
 
2254
  robot_states = nn.Embedding(
2255
  num_embeddings=num_steps, embedding_dim=60,
2256
- ).cuda()
2257
  torch.nn.init.zeros_(robot_states.weight)
2258
  # params_to_train += list(robot_states.parameters())
2259
 
2260
  # robot_init_states = nn.Embedding(
2261
  # num_embeddings=1, embedding_dim=22,
2262
- # ).cuda()
2263
  # torch.nn.init.zeros_(robot_init_states.weight)
2264
  # # params_to_train += list(robot_init_states.parameters())
2265
 
2266
  robot_glb_rotation = nn.Embedding(
2267
  num_embeddings=num_steps, embedding_dim=4
2268
- ).cuda()
2269
  robot_glb_rotation.weight.data[:, 0] = 1.
2270
  robot_glb_rotation.weight.data[:, 1:] = 0.
2271
 
2272
 
2273
  robot_glb_trans = nn.Embedding(
2274
  num_embeddings=num_steps, embedding_dim=3
2275
- ).cuda()
2276
  torch.nn.init.zeros_(robot_glb_trans.weight)
2277
 
2278
  ''' Load optimized MANO hand actions and states '''
@@ -2531,7 +2531,7 @@ if __name__=='__main__': # # #
2531
  tot_transformed_pts = []
2532
  for i_ts in range(len(blended_states)):
2533
  cur_blended_states = blended_states[i_ts]
2534
- cur_blended_states = torch.from_numpy(cur_blended_states).float().cuda()
2535
  robot_agent.active_robot.set_delta_state_and_update_v2(cur_blended_states, 0)
2536
  cur_pts = robot_agent.get_init_state_visual_pts().detach().cpu().numpy()
2537
 
@@ -2568,7 +2568,7 @@ if __name__=='__main__': # # #
2568
  shadow_hand_sv_fn = "/home/xueyi/diffsim/NeuS/raw_data/scaled_redmax_hand_rescaled_grab.obj"
2569
  shadow_hand_mesh.export(shadow_hand_sv_fn)
2570
 
2571
- init_joint_states = torch.randn((60, ), dtype=torch.float32).cuda()
2572
  robot_agent.set_initial_state(init_joint_states)
2573
 
2574
 
@@ -2676,15 +2676,15 @@ if __name__=='__main__': # # #
2676
  # mesh_obj.export(f"hand_urdf.ply")
2677
 
2678
  ##### Test the set initial state function #####
2679
- init_joint_states = torch.zeros((60, ), dtype=torch.float32).cuda()
2680
  cur_robot.set_initial_state(init_joint_states)
2681
  ##### Test the set initial state function #####
2682
 
2683
 
2684
 
2685
 
2686
- cur_zeros_actions = torch.zeros((60, ), dtype=torch.float32).cuda()
2687
- cur_ones_actions = torch.ones((60, ), dtype=torch.float32).cuda() # * 100
2688
 
2689
  ts_to_mesh_verts = {}
2690
  for i_ts in range(50):
 
147
  self.mass = mass
148
  self.inertia = inertia
149
  if torch.sum(self.inertia).item() < 1e-4:
150
+ self.inertia = self.inertia + torch.eye(3, dtype=torch.float32)
151
  pass
152
 
153
  class Visual:
 
189
  vertices = mesh.vertices
190
  faces = mesh.faces
191
 
192
+ vertices = torch.from_numpy(vertices).float()
193
+ faces =torch.from_numpy(faces).long()
194
 
195
  vertices = vertices * self.geometry_mesh_scale.unsqueeze(0) + self.visual_xyz.unsqueeze(0)
196
 
 
244
  # sample from the circile with cur_pts as thejcenter and the radius as expand_r
245
  # (-r, r) # sample the offset vector in the size of (nn_expand_pts, 3)
246
  offset_dist = Uniform(-1. * expand_r, expand_r)
247
+ offset_vec = offset_dist.sample((nn_expand_pts, 3))
248
  cur_expanded_pts = cur_pts + offset_vec
249
  cur_expanded_visual_pts.append(cur_expanded_pts)
250
  cur_expanded_visual_pts = torch.cat(cur_expanded_visual_pts, dim=0)
 
252
  else:
253
  print(f"Loading visual pts from {expand_save_fn}") # load from the fn #
254
  cur_expanded_visual_pts = np.load(expand_save_fn, allow_pickle=True)
255
+ cur_expanded_visual_pts = torch.from_numpy(cur_expanded_visual_pts).float()
256
  self.cur_expanded_visual_pts = cur_expanded_visual_pts # expanded visual pts #
257
  return self.cur_expanded_visual_pts
258
 
 
343
  try:
344
  cur_child_inertia = cur_child_struct.cur_inertia
345
  except:
346
+ cur_child_inertia = torch.eye(3, dtype=torch.float32)
347
 
348
 
349
  if cur_joint.type in ['revolute'] and (cur_joint_name not in ['WRJ2', 'WRJ1']):
 
428
  joint_link_idxes = torch.cat(joint_link_idxes, dim=-1) ### joint_link idxes ###
429
  cur_joint_idx = cur_child.link_idx
430
  joint_link_idxes = torch.cat(
431
+ [joint_link_idxes, torch.tensor([cur_joint_idx], dtype=torch.long)], dim=-1
432
  )
433
  else:
434
+ joint_link_idxes = torch.tensor([cur_child.link_idx], dtype=torch.long).view(1,)
435
 
436
  # joint link idxes #
437
 
 
446
 
447
  # joint_origin_xyz = self.joint.origin_xyz # c ## get forces from the expanded point set ##
448
  else:
449
+ joint_origin_xyz = torch.tensor([0., 0., 0.], dtype=torch.float32)
450
  # self.parent_rot_mtx = parent_rot
451
  # self.parent_trans_vec = parent_trans + joint_origin_xyz
452
 
 
455
  # ## get init visual meshes ## ## --
456
  init_visual_meshes = self.visual.get_init_visual_meshes(parent_rot, parent_trans, init_visual_meshes, expanded_pts=expanded_pts)
457
  cur_visual_mesh_pts_nn = self.visual.vertices.size(0)
458
+ cur_link_idxes = torch.zeros((cur_visual_mesh_pts_nn, ), dtype=torch.long)+ self.link_idx
459
  init_visual_meshes['link_idxes'].append(cur_link_idxes)
460
 
461
  # self.link_idx #
462
  if joint_idxes is not None:
463
  cur_idxes = [self.link_idx for _ in range(cur_visual_mesh_pts_nn)]
464
+ cur_idxes = torch.tensor(cur_idxes, dtype=torch.long)
465
  joint_idxes.append(cur_idxes)
466
 
467
 
 
473
  # calculate inerti
474
  def calculate_inertia(self, link_name_to_visited, link_name_to_link_struct):
475
  link_name_to_visited[self.name] = 1
476
+ self.cur_inertia = torch.zeros((3, 3), dtype=torch.float32)
477
 
478
  if self.joint is not None:
479
  for joint_nm in self.joint:
 
634
  try:
635
  cur_child_inertia = cur_child_struct.cur_inertia
636
  except:
637
+ cur_child_inertia = torch.eye(3, dtype=torch.float32)
638
 
639
 
640
  if cur_joint.type in ['revolute']:
 
663
  self.joint_angle = init_states[self.joint.joint_idx]
664
  joint_axis = self.joint.axis
665
  self.rot_vec = self.joint_angle * joint_axis
666
+ self.joint.state = torch.tensor([1, 0, 0, 0], dtype=torch.float32)
667
  self.joint.state = self.joint.state + update_quaternion(self.rot_vec, self.joint.state)
668
  self.joint.timestep_to_states[0] = self.joint.state.detach()
669
+ self.joint.timestep_to_vels[0] = torch.zeros((3,), dtype=torch.float32).detach() ## velocity ##
670
  for cur_link in self.children:
671
  cur_link.set_init_states_target_value(init_states)
672
 
673
  # should forward for one single step -> use the action #
674
  def set_init_states(self, ):
675
+ self.joint.state = torch.tensor([1, 0, 0, 0], dtype=torch.float32)
676
  self.joint.timestep_to_states[0] = self.joint.state.detach()
677
+ self.joint.timestep_to_vels[0] = torch.zeros((3,), dtype=torch.float32).detach() ## velocity ##
678
  for cur_link in self.children:
679
  cur_link.set_init_states()
680
 
 
737
 
738
  #### only for the current state #### # joint urdf #
739
  self.state = nn.Parameter(
740
+ torch.tensor([1., 0., 0., 0.], dtype=torch.float32, requires_grad=True), requires_grad=True
741
  )
742
  self.action = nn.Parameter(
743
+ torch.zeros((1,), dtype=torch.float32, requires_grad=True), requires_grad=True
744
  )
745
  # self.rot_mtx = np.eye(3, dtypes=np.float32)
746
  # self.trans_vec = np.zeros((3,), dtype=np.float32) ## rot m
747
+ self.rot_mtx = nn.Parameter(torch.eye(n=3, dtype=torch.float32, requires_grad=True), requires_grad=True)
748
+ self.trans_vec = nn.Parameter(torch.zeros((3,), dtype=torch.float32, requires_grad=True), requires_grad=True)
749
 
750
  def set_initial_state(self, state):
751
  # joint angle as the state value #
752
+ self.timestep_to_vels[0] = torch.zeros((3,), dtype=torch.float32).detach() ## velocity ##
753
  delta_rot_vec = self.axis_xyz * state
754
  # self.timestep_to_states[0] = state.detach()
755
+ cur_state = torch.tensor([1., 0., 0., 0.], dtype=torch.float32)
756
  init_state = cur_state + update_quaternion(delta_rot_vec, cur_state)
757
  self.timestep_to_states[0] = init_state.detach()
758
  self.state = init_state
759
 
760
  def set_delta_state_and_update(self, state, cur_timestep):
761
+ self.timestep_to_vels[cur_timestep] = torch.zeros((3,), dtype=torch.float32).detach()
762
  delta_rot_vec = self.axis_xyz * state
763
  if cur_timestep == 0:
764
+ prev_state = torch.tensor([1., 0., 0., 0.], dtype=torch.float32)
765
  else:
766
  # prev_state = self.timestep_to_states[cur_timestep - 1].detach()
767
  prev_state = self.timestep_to_states[cur_timestep - 1] # .detach() # not detach? #
 
771
 
772
 
773
  def set_delta_state_and_update_v2(self, delta_state, cur_timestep):
774
+ self.timestep_to_vels[cur_timestep] = torch.zeros((3,), dtype=torch.float32).detach()
775
 
776
  if cur_timestep == 0:
777
  cur_state = delta_state
 
787
 
788
  cur_rot_vec = self.axis_xyz * cur_state ### cur_state #### #
789
  # angle to the quaternion ? #
790
+ init_state = torch.tensor([1., 0., 0., 0.], dtype=torch.float32)
791
  cur_quat_state = init_state + update_quaternion(cur_rot_vec, init_state)
792
  self.state = cur_quat_state
793
 
794
  # if cur_timestep == 0:
795
+ # prev_state = torch.tensor([1., 0., 0., 0.], dtype=torch.float32)
796
  # else:
797
  # # prev_state = self.timestep_to_states[cur_timestep - 1].detach()
798
  # prev_state = self.timestep_to_states[cur_timestep - 1] # .detach() # not detach? #
 
816
  self.rot_mtx = rot_mtx
817
  self.trans_vec = trans_vec
818
  elif self.type == "fixed":
819
+ rot_mtx = torch.eye(3, dtype=torch.float32)
820
+ trans_vec = torch.zeros((3,), dtype=torch.float32)
821
  # trans_vec = self.origin_xyz
822
  self.rot_mtx = rot_mtx
823
  self.trans_vec = trans_vec #
 
840
  torque = self.action * self.axis_xyz
841
 
842
  # # Compute inertia matrix #
843
+ # inertial = torch.zeros((3, 3), dtype=torch.float32)
844
  # for i_pts in range(self.visual_pts.size(0)):
845
  # cur_pts = self.visual_pts[i_pts]
846
  # cur_pts_mass = self.visual_pts_mass[i_pts]
 
848
  # # cur_vert = init_passive_mesh[i_v]
849
  # # cur_r = cur_vert - init_passive_mesh_center
850
  # dot_r_r = torch.sum(cur_r * cur_r)
851
+ # cur_eye_mtx = torch.eye(3, dtype=torch.float32)
852
  # r_mult_rT = torch.matmul(cur_r.unsqueeze(-1), cur_r.unsqueeze(0))
853
  # inertial += (dot_r_r * cur_eye_mtx - r_mult_rT) * cur_pts_mass
854
  # m = torch.sum(self.visual_pts_mass)
 
935
  torque = self.action * self.axis_xyz
936
 
937
  # # Compute inertia matrix #
938
+ # inertial = torch.zeros((3, 3), dtype=torch.float32)
939
  # for i_pts in range(self.visual_pts.size(0)):
940
  # cur_pts = self.visual_pts[i_pts]
941
  # cur_pts_mass = self.visual_pts_mass[i_pts]
 
943
  # # cur_vert = init_passive_mesh[i_v]
944
  # # cur_r = cur_vert - init_passive_mesh_center
945
  # dot_r_r = torch.sum(cur_r * cur_r)
946
+ # cur_eye_mtx = torch.eye(3, dtype=torch.float32)
947
  # r_mult_rT = torch.matmul(cur_r.unsqueeze(-1), cur_r.unsqueeze(0))
948
  # inertial += (dot_r_r * cur_eye_mtx - r_mult_rT) * cur_pts_mass
949
  # m = torch.sum(self.visual_pts_mass)
 
957
 
958
  # inertia_inv = torch.linalg.inv(cur_inertia).detach()
959
 
960
+ inertia_inv = torch.eye(n=3, dtype=torch.float32)
961
 
962
 
963
 
 
994
 
995
  # if cur_timestep
996
  if cur_timestep == 0:
997
+ self.timestep_to_states[cur_timestep] = torch.zeros((1,), dtype=torch.float32)
998
  cur_state = self.timestep_to_states[cur_timestep].detach()
999
  nex_state = cur_state + delta_state
1000
  # nex_state = nex_state + penetration_delta_state
1001
  ## state rot vector along axis ## ## get the pentrated froces -- calulaterot qj
1002
  state_rot_vec_along_axis = nex_state * self.axis_xyz
1003
  ### state in the rotation vector -> state in quaternion ###
1004
+ state_rot_quat = torch.tensor([1., 0., 0., 0.], dtype=torch.float32) + update_quaternion(state_rot_vec_along_axis, torch.tensor([1., 0., 0., 0.], dtype=torch.float32))
1005
  ### state
1006
  self.state = state_rot_quat
1007
  ### get states? ##
 
1060
  # penetration_delta_state = forces_torques_dot_axis
1061
  else:
1062
  penetration_delta_state = 0.0
1063
+ cur_joint_maximal_forces = torch.zeros((6,), dtype=torch.float32)
1064
  cur_joint_idx = joint_idx
1065
  joint_penetration_forces[cur_joint_idx][:] = cur_joint_maximal_forces[:].clone()
1066
 
 
1095
  # #
1096
  self.act_joint_idxes = list(self.actions_joint_name_to_joint_idx.values())
1097
  self.act_joint_idxes = sorted(self.act_joint_idxes, reverse=False)
1098
+ self.act_joint_idxes = torch.tensor(self.act_joint_idxes, dtype=torch.long)[2:]
1099
 
1100
  self.real_actions_joint_name_to_joint_idx = real_actions_joint_name_to_joint_idx
1101
 
 
1136
  init_visual_meshes = {
1137
  'vertices': [], 'faces': [], 'link_idxes': [], 'transformed_joint_pos': [], 'link_idxes': [], 'transformed_joint_pos': [], 'joint_link_idxes': []
1138
  }
1139
+ init_parent_rot = torch.eye(3, dtype=torch.float32)
1140
+ init_parent_trans = torch.zeros((3,), dtype=torch.float32)
1141
 
1142
  palm_idx = self.link_name_to_link_idxes["palm"]
1143
  palm_link = self.links[palm_idx]
 
1174
  action_joint_name_to_joint_idx = self.real_actions_joint_name_to_joint_idx
1175
  # print(f"action_joint_name_to_joint_idx: {action_joint_name_to_joint_idx}")
1176
 
1177
+ parent_rot = torch.eye(3, dtype=torch.float32)
1178
+ parent_trans = torch.zeros((3,), dtype=torch.float32)
1179
 
1180
  # cur_timestep, time_cons, action_joint_name_to_joint_idx, link_name_to_visited, link_name_to_link_struct, parent_rot, parent_trans, penetration_forces, sampled_visual_pts_joint_idxes, joint_penetration_forces):
1181
 
 
1251
 
1252
  link_name_to_visited = {}
1253
 
1254
+ # parent_rot = torch.eye(3, dtype=torch.float32)
1255
+ # parent_trans = torch.zeros((3,), dtype=torch.float32)
1256
  ## set actions ## #
1257
  # set_actions_and_update_states_v2(self, action, cur_timestep, time_cons, cur_inertia):
1258
  # self, actions, cur_timestep, time_cons, action_joint_name_to_joint_idx, link_name_to_visited, link_name_to_link_struct ## set and update states ##
 
1271
 
1272
  link_name_to_visited = {}
1273
 
1274
+ parent_rot = torch.eye(3, dtype=torch.float32)
1275
+ parent_trans = torch.zeros((3,), dtype=torch.float32)
1276
  ## set actions ## #
1277
  # set_actions_and_update_states_v2(self, action, cur_timestep, time_cons, cur_inertia):
1278
  # self, actions, cur_timestep, time_cons, action_joint_name_to_joint_idx, link_name_to_visited, link_name_to_link_struct ## set and update states ##
 
1363
  vals = np.array(vals, dtype=np.float32)
1364
  vals = torch.from_numpy(vals).float()
1365
  ## vals ##
1366
+ vals = nn.Parameter(vals, requires_grad=True)
1367
 
1368
  return vals
1369
 
 
1486
  inertial_izz = inertial_inertia.attrib["izz"]
1487
  inertial_izz = float(inertial_izz)
1488
 
1489
+ inertial_inertia_mtx = torch.zeros((3, 3), dtype=torch.float32)
1490
  inertial_inertia_mtx[0, 0] = inertial_ixx
1491
  inertial_inertia_mtx[0, 1] = inertial_ixy
1492
  inertial_inertia_mtx[0, 2] = inertial_ixz
 
1547
  origin_xyz_string = joint_origin.attrib["xyz"]
1548
  origin_xyz = parse_nparray_from_string(origin_xyz_string)
1549
  except:
1550
+ origin_xyz = torch.tensor([0., 0., 0.], dtype=torch.float32)
1551
  origin_xyz_string = ""
1552
 
1553
  joint_axis = joint.find("./axis")
 
1555
  joint_axis = joint_axis.attrib["xyz"]
1556
  joint_axis = parse_nparray_from_string(joint_axis)
1557
  else:
1558
+ joint_axis = torch.tensor([1, 0., 0.], dtype=torch.float32)
1559
 
1560
  joint_limit = joint.find("./limit")
1561
  if joint_limit is not None:
 
1746
 
1747
  self.time_constant = nn.Embedding(
1748
  num_embeddings=3, embedding_dim=1
1749
+ )
1750
  torch.nn.init.ones_(self.time_constant.weight) #
1751
  self.time_constant.weight.data = self.time_constant.weight.data * 0.2 ### time_constant data #
1752
 
1753
  self.optimizable_actions = nn.Embedding(
1754
  num_embeddings=100, embedding_dim=60,
1755
+ )
1756
  torch.nn.init.zeros_(self.optimizable_actions.weight) #
1757
 
1758
  self.learning_rate = 5e-4
 
1770
 
1771
 
1772
  def set_init_states_target_value(self, init_states):
1773
+ # glb_rot = torch.eye(n=3, dtype=torch.float32)
1774
+ # glb_trans = torch.zeros((3,), dtype=torch.float32) ### glb_trans #### and the rot 3##
1775
 
1776
  # tot_init_states = {}
1777
  # tot_init_states['glb_rot'] = glb_rot;
1778
  # tot_init_states['glb_trans'] = glb_trans;
1779
  # tot_init_states['links_init_states'] = init_states
1780
  # self.active_robot.set_init_states_target_value(tot_init_states)
1781
+ # init_joint_states = torch.zeros((60, ), dtype=torch.float32)
1782
  self.active_robot.set_initial_state(init_states)
1783
 
1784
  def set_init_states(self):
1785
+ # glb_rot = torch.eye(n=3, dtype=torch.float32)
1786
+ # glb_trans = torch.zeros((3,), dtype=torch.float32) ### glb_trans #### and the rot 3##
1787
 
1788
  # ### random rotation ###
1789
  # # glb_rot_np = R.random().as_matrix()
1790
+ # # glb_rot = torch.from_numpy(glb_rot_np).float()
1791
  # ### random rotation ###
1792
 
1793
  # # glb_rot, glb_trans #
 
1796
  # init_states['glb_trans'] = glb_trans;
1797
  # self.active_robot.set_init_states(init_states)
1798
 
1799
+ init_joint_states = torch.zeros((60, ), dtype=torch.float32)
1800
  self.active_robot.set_initial_state(init_joint_states)
1801
 
1802
  # cur_verts, joint_idxes = get_init_state_visual_pts(expanded_pts=False, ret_joint_idxes=True)
 
1821
 
1822
  def set_actions_and_update_states(self, actions, cur_timestep):
1823
  #
1824
+ time_cons = self.time_constant(torch.zeros((1,), dtype=torch.long)) ### time constant of the system ##
1825
  self.active_robot.set_actions_and_update_states(actions, cur_timestep, time_cons) ###
1826
  pass
1827
 
1828
  def set_actions_and_update_states_v2(self, actions, cur_timestep, penetration_forces=None, sampled_visual_pts_joint_idxes=None):
1829
  #
1830
+ time_cons = self.time_constant(torch.zeros((1,), dtype=torch.long)) ### time constant of the system ##
1831
  self.active_robot.set_actions_and_update_states_v2(actions, cur_timestep, time_cons, penetration_forces=penetration_forces, sampled_visual_pts_joint_idxes=sampled_visual_pts_joint_idxes) ###
1832
  pass
1833
 
 
1841
  timestep_to_visual_pts = {}
1842
  for i_step in range(50):
1843
  actions = {}
1844
+ actions['delta_glb_rot'] = torch.eye(3, dtype=torch.float32)
1845
+ actions['delta_glb_trans'] = torch.zeros((3,), dtype=torch.float32)
1846
+ actions_link_actions = torch.ones((22, ), dtype=torch.float32)
1847
  # actions_link_actions = actions_link_actions * 0.2
1848
  actions_link_actions = actions_link_actions * -1. #
1849
  actions['link_actions'] = actions_link_actions
 
1862
  # TODO: load reference points #
1863
  self.ts_to_reference_pts = np.load(reference_pts_dict, allow_pickle=True).item() ####
1864
  self.ts_to_reference_pts = {
1865
+ ts // 2 + 1: torch.from_numpy(self.ts_to_reference_pts[ts]).float() for ts in self.ts_to_reference_pts
1866
  }
1867
 
1868
 
 
1879
  for cur_ts in range(self.n_timesteps):
1880
  # print(f"iter: {i_iter}, cur_ts: {cur_ts}")
1881
  # actions = {}
1882
+ # actions['delta_glb_rot'] = torch.eye(3, dtype=torch.float32)
1883
+ # actions['delta_glb_trans'] = torch.zeros((3,), dtype=torch.float32)
1884
+ actions_link_actions = self.optimizable_actions(torch.zeros((1,), dtype=torch.long) + cur_ts).squeeze(0)
1885
  # actions_link_actions = actions_link_actions * 0.2
1886
  # actions_link_actions = actions_link_actions * -1. #
1887
  # actions['link_actions'] = actions_link_actions
 
2240
  # robot actions # #
2241
  # robot_actions = nn.Embedding(
2242
  # num_embeddings=num_steps, embedding_dim=22,
2243
+ # )
2244
  # torch.nn.init.zeros_(robot_actions.weight)
2245
  # # params_to_train += list(robot_actions.parameters())
2246
 
2247
  # robot_delta_states = nn.Embedding(
2248
  # num_embeddings=num_steps, embedding_dim=60,
2249
+ # )
2250
  # torch.nn.init.zeros_(robot_delta_states.weight)
2251
  # # params_to_train += list(robot_delta_states.parameters())
2252
 
2253
 
2254
  robot_states = nn.Embedding(
2255
  num_embeddings=num_steps, embedding_dim=60,
2256
+ )
2257
  torch.nn.init.zeros_(robot_states.weight)
2258
  # params_to_train += list(robot_states.parameters())
2259
 
2260
  # robot_init_states = nn.Embedding(
2261
  # num_embeddings=1, embedding_dim=22,
2262
+ # )
2263
  # torch.nn.init.zeros_(robot_init_states.weight)
2264
  # # params_to_train += list(robot_init_states.parameters())
2265
 
2266
  robot_glb_rotation = nn.Embedding(
2267
  num_embeddings=num_steps, embedding_dim=4
2268
+ )
2269
  robot_glb_rotation.weight.data[:, 0] = 1.
2270
  robot_glb_rotation.weight.data[:, 1:] = 0.
2271
 
2272
 
2273
  robot_glb_trans = nn.Embedding(
2274
  num_embeddings=num_steps, embedding_dim=3
2275
+ )
2276
  torch.nn.init.zeros_(robot_glb_trans.weight)
2277
 
2278
  ''' Load optimized MANO hand actions and states '''
 
2531
  tot_transformed_pts = []
2532
  for i_ts in range(len(blended_states)):
2533
  cur_blended_states = blended_states[i_ts]
2534
+ cur_blended_states = torch.from_numpy(cur_blended_states).float()
2535
  robot_agent.active_robot.set_delta_state_and_update_v2(cur_blended_states, 0)
2536
  cur_pts = robot_agent.get_init_state_visual_pts().detach().cpu().numpy()
2537
 
 
2568
  shadow_hand_sv_fn = "/home/xueyi/diffsim/NeuS/raw_data/scaled_redmax_hand_rescaled_grab.obj"
2569
  shadow_hand_mesh.export(shadow_hand_sv_fn)
2570
 
2571
+ init_joint_states = torch.randn((60, ), dtype=torch.float32)
2572
  robot_agent.set_initial_state(init_joint_states)
2573
 
2574
 
 
2676
  # mesh_obj.export(f"hand_urdf.ply")
2677
 
2678
  ##### Test the set initial state function #####
2679
+ init_joint_states = torch.zeros((60, ), dtype=torch.float32)
2680
  cur_robot.set_initial_state(init_joint_states)
2681
  ##### Test the set initial state function #####
2682
 
2683
 
2684
 
2685
 
2686
+ cur_zeros_actions = torch.zeros((60, ), dtype=torch.float32)
2687
+ cur_ones_actions = torch.ones((60, ), dtype=torch.float32) # * 100
2688
 
2689
  ts_to_mesh_verts = {}
2690
  for i_ts in range(50):
models/fields.py CHANGED
The diff for this file is too large to render. See raw diff
 
pre-requirements.txt CHANGED
@@ -1,2 +1,2 @@
1
  pip==23.3.2
2
- torch==2.2.0
 
1
  pip==23.3.2
2
+ # torch==2.2.0
requirements.txt CHANGED
@@ -1,5 +1,6 @@
1
  -f https://download.pytorch.org/whl/cpu/torch_stable.html
2
- -f https://data.pyg.org/whl/torch-2.2.0%2Bcpu.html
 
3
  # pip==20.2.4
4
  torch==2.2.0
5
  # torchvision==0.13.1
 
1
  -f https://download.pytorch.org/whl/cpu/torch_stable.html
2
+ # -f https://data.pyg.org/whl/torch-2.2.0%2Bcpu.html
3
+ -i https://download.pytorch.org/whl/cpu
4
  # pip==20.2.4
5
  torch==2.2.0
6
  # torchvision==0.13.1
scripts_demo/train_grab_pointset_points_dyn_s1.sh CHANGED
@@ -18,11 +18,15 @@ export conf=dyn_grab_pointset_points_dyn_s1.conf
18
  export conf_root="./confs_new"
19
 
20
 
21
- export data_path="./data/102_grab_all_data.npy"
22
- # bash scripts_new/train_grab_pointset_points_dyn_s1.sh
 
 
23
 
24
  export cuda_ids="0"
25
 
26
 
27
- CUDA_VISIBLE_DEVICES=${cuda_ids} python ${trainer} --mode ${mode} --conf ${conf_root}/${conf} --case ${data_case} --data_path=${data_path}
 
 
28
 
 
18
  export conf_root="./confs_new"
19
 
20
 
21
+ export data_fn="./data/102_grab_all_data.npy"
22
+
23
+
24
+ # bash scripts_demo/train_grab_pointset_points_dyn_s1.sh
25
 
26
  export cuda_ids="0"
27
 
28
 
29
+ # CUDA_VISIBLE_DEVICES=${cuda_ids}
30
+
31
+ python ${trainer} --mode ${mode} --conf ${conf_root}/${conf} --case ${data_case} --data_fn=${data_fn}
32