import pickle
from os import getcwd
from os.path import abspath, dirname, basename, splitext

import numpy as np
import pybullet as p
import PySimpleGUI as sg
from transforms3d.affines import decompose
from transforms3d.quaternions import mat2quat
from urdfpy import URDF


class PyBulletRecorder:
    """Record per-link poses of PyBullet bodies and pickle them for playback.

    Typical use: call ``register_object`` once per loaded URDF body, call
    ``add_keyframe`` after simulation steps, then ``save`` (or
    ``prompt_save``) to write the collected frames to disk.
    """

    # Maps the geometry tag stored in ``self.links`` to the ``'type'`` value
    # written to the output for primitive (non-mesh) shapes.
    _PRIMITIVE_TYPES = {
        'box': 'cube',
        'cylinder': 'cylinder',
        'sphere': 'sphere',
    }

    class LinkTracker:
        """Tracks the world pose of one visual element of one link."""

        def __init__(self,
                     name,
                     body_id,
                     link_id,
                     link_origin,
                     mesh_path,
                     mesh_scale,
                     mesh_material=None):
            """Store identifiers and pre-compute the visual's local offset.

            Args:
                name: Unique key used for this visual in recorded states.
                body_id: PyBullet body unique id.
                link_id: PyBullet link index, or -1 for the base link.
                link_origin: 4x4 affine matrix placing the visual relative
                    to the frame PyBullet reports for the link.
                mesh_path: Mesh file path, or a primitive tag such as 'box'.
                mesh_scale: Scale stored alongside the visual.
                mesh_material: Optional material object (read for ``name``
                    and ``color`` when formatting output).
            """
            self.body_id = body_id
            self.link_id = link_id
            self.mesh_path = mesh_path
            self.mesh_scale = mesh_scale
            self.mesh_material = mesh_material
            # decompose() yields (translation, rotation matrix, zooms, shears).
            decomposed_origin = decompose(link_origin)
            orn = mat2quat(decomposed_origin[1])
            # transforms3d quaternions are (w, x, y, z); PyBullet expects
            # (x, y, z, w), so move w to the end.
            orn = [orn[1], orn[2], orn[3], orn[0]]
            self.link_pose = [decomposed_origin[0], orn]
            self.name = name

        def transform(self, position, orientation):
            """Compose the given pose with this visual's local offset."""
            return p.multiplyTransforms(
                position, orientation,
                self.link_pose[0], self.link_pose[1],
            )

        def get_keyframe(self):
            """Return the visual's current pose.

            Returns:
                dict with ``'position'`` and ``'orientation'`` as lists.
            """
            if self.link_id == -1:
                # The base has no link index; query it through the base API.
                position, orientation = p.getBasePositionAndOrientation(
                    self.body_id)
                position, orientation = self.transform(
                    position=position, orientation=orientation)
            else:
                link_state = p.getLinkState(self.body_id,
                                            self.link_id,
                                            computeForwardKinematics=True)
                # Entries 4 and 5 of the link state are used as the link
                # frame position / orientation.
                position, orientation = self.transform(
                    position=link_state[4],
                    orientation=link_state[5])
            return {
                'position': list(position),
                'orientation': list(orientation)
            }

    def __init__(self):
        """Create an empty recorder with no tracked links and no frames."""
        self.states = []
        self.links = []

    def _add_tracker(self, geometry_type, name, body_id, link_id,
                     link_origin, mesh_path, mesh_scale, mesh_material):
        """Append one ``(geometry_type, LinkTracker)`` pair to ``self.links``."""
        self.links.append((geometry_type, PyBulletRecorder.LinkTracker(
            name=name,
            body_id=body_id,
            link_id=link_id,
            link_origin=link_origin,
            mesh_path=mesh_path,
            mesh_scale=mesh_scale,
            mesh_material=mesh_material)))

    def register_object(self, body_id, urdf_path, global_scaling=1, color=None):
        """Register every visual of a loaded URDF body for recording.

        Args:
            body_id: PyBullet body unique id of the already-loaded body.
            urdf_path: Path to the URDF the body was loaded from; mesh file
                names are resolved relative to its directory.
            global_scaling: Uniform scale applied to visual origins and
                mesh scales.
            color: Optional colour assigned to every visual that has a
                material; such materials are also renamed with a random
                numeric suffix.
        """
        # Build link-name -> PyBullet link index; the base link maps to -1.
        link_id_map = {}
        n = p.getNumJoints(body_id)
        link_id_map[str(p.getBodyInfo(body_id)[0].decode('gb2312'))] = -1
        for link_id in range(0, n):
            link_id_map[str(p.getJointInfo(body_id, link_id)[
                12].decode('gb2312'))] = link_id

        dir_path = dirname(abspath(urdf_path))
        file_name = splitext(basename(urdf_path))[0]
        robot = URDF.load(urdf_path)
        for link in robot.links:
            if link.name not in link_id_map:
                print("skip links !! \n", link.name, link_id_map, len(robot.links), p.getBodyInfo(body_id)[0].decode('gb2312'))
                continue
            link_id = link_id_map[link.name]
            for i, link_visual in enumerate(link.visuals):
                mesh_material = None
                if link_visual.material is not None:
                    mesh_material = link_visual.material
                    if color is not None:
                        # NOTE(review): this mutates the material object held
                        # by the loaded URDF; if several visuals share one
                        # material object the suffix would be appended more
                        # than once - confirm against urdfpy behaviour.
                        mesh_material.name = mesh_material.name + f"_{np.random.randint(100)}"  # mark it
                        mesh_material.color = color

                geometry = link_visual.geometry
                tracker_name = file_name + f'_{body_id}_{link.name}_{i}'
                # If link_id == -1 then is base link,
                # PyBullet will return
                # inertial_origin @ visual_origin,
                # so need to undo that transform
                link_origin = (np.linalg.inv(link.inertial.origin)
                               if link_id == -1
                               else np.identity(4)) @ link_visual.origin * global_scaling

                if geometry.mesh is not None:
                    print("use mesh", i, link_id_map.keys())
                    mesh_scale = [global_scaling, global_scaling, global_scaling]\
                        if geometry.mesh.scale is None \
                        else geometry.mesh.scale * global_scaling
                    self._add_tracker(
                        'mesh', tracker_name, body_id, link_id, link_origin,
                        mesh_path=dir_path + '/' + geometry.mesh.filename,
                        mesh_scale=mesh_scale,
                        mesh_material=mesh_material)

                if geometry.box is not None:
                    print("use box", i, link_id_map.keys(), link_visual.geometry.box.__dict__)
                    # Half of the box size is stored as the scale.
                    mesh_scale = geometry.box.size / 2
                    self._add_tracker(
                        'box', tracker_name, body_id, link_id, link_origin,
                        mesh_path='box',
                        mesh_scale=mesh_scale,
                        mesh_material=mesh_material)

                if geometry.cylinder is not None:
                    print("use cylinder", i, link_id_map.keys(), link_visual.geometry.cylinder.__dict__)
                    mesh_scale = [geometry.cylinder.radius,
                                  geometry.cylinder.radius,
                                  geometry.cylinder.length]
                    self._add_tracker(
                        'cylinder', tracker_name, body_id, link_id, link_origin,
                        mesh_path='cylinder',
                        mesh_scale=mesh_scale,
                        mesh_material=mesh_material)

                if geometry.sphere is not None:
                    print("use sphere", i, link_id_map.keys(), link_visual.geometry.sphere.__dict__)
                    mesh_scale = [geometry.sphere.radius,
                                  geometry.sphere.radius,
                                  geometry.sphere.radius]
                    self._add_tracker(
                        'sphere', tracker_name, body_id, link_id, link_origin,
                        mesh_path='sphere',
                        mesh_scale=mesh_scale,
                        mesh_material=mesh_material)

    def add_keyframe(self):
        """Snapshot the pose of every tracked visual as one new frame."""
        # Ideally, call every p.stepSimulation()
        current_state = {
            link.name: link.get_keyframe() for _, link in self.links
        }
        self.states.append(current_state)

    def prompt_save(self):
        """Ask via GUI whether and where to save, then clear recorded frames.

        Closing either dialog, or confirming an empty path, results in
        ``save(None)`` (nothing written). Frames are reset in every case.
        """
        layout = [[sg.Text('Do you want to save previous episode?')],
                  [sg.Button('Yes'), sg.Button('No')]]
        window = sg.Window('PyBullet Recorder', layout)
        save = False
        while True:
            event, values = window.read()
            if event in (None, 'No'):
                break
            elif event == 'Yes':
                save = True
                break
        window.close()
        if save:
            layout = [[sg.Text('Where do you want to save it?')],
                      [sg.Text('Path'), sg.InputText(getcwd())],
                      [sg.Button('OK')]]
            window = sg.Window('PyBullet Recorder', layout)
            event, values = window.read()
            window.close()
            # A closed window can come back without values; indexing it
            # directly would raise. Treat that (and an empty path) as
            # "do not save" by handing None to save().
            if event is None or not values:
                path = None
            else:
                path = values[0] or None
            self.save(path)
        self.reset()

    def reset(self):
        """Discard all recorded frames; registered links are kept."""
        self.states = []

    def get_formatted_output(self):
        """Build the serialisable recording.

        Returns:
            dict mapping each tracked visual's name to a dict holding its
            ``'type'``, ``'mesh_scale'``, per-frame poses under ``'frames'``,
            plus ``'mesh_path'`` for meshes or ``'name'`` for primitives,
            and material name / colour when a material is present.
        """
        retval = {}
        for geo_name, link in self.links:
            if geo_name == 'mesh':
                entry = {
                    'type': 'mesh',
                    'mesh_path': link.mesh_path,
                    'mesh_scale': link.mesh_scale,
                    'frames': [state[link.name] for state in self.states]
                }
            elif geo_name in self._PRIMITIVE_TYPES:
                entry = {
                    'type': self._PRIMITIVE_TYPES[geo_name],
                    'name': link.name,
                    'mesh_scale': link.mesh_scale,
                    'frames': [state[link.name] for state in self.states]
                }
            else:
                # Unknown geometry tag: emit nothing rather than failing
                # when attaching material info below.
                continue
            if link.mesh_material is not None:
                entry['mesh_material_name'] = link.mesh_material.name
                entry['mesh_material_color'] = link.mesh_material.color
            retval[link.name] = entry
        return retval

    def save(self, path):
        """Pickle the formatted recording to ``path``.

        Args:
            path: Destination file path; when ``None`` nothing is written.
        """
        if path is None:
            print("[Recorder] Path is None.. not saving")
        else:
            print("[Recorder] Saving state to {}".format(path))
            # Context manager guarantees the file is flushed and closed.
            with open(path, 'wb') as output_file:
                pickle.dump(self.get_formatted_output(), output_file)