#!/usr/bin/python3
import time

import cv2
import h5py
import numpy as np
import torch
import yaml
from PIL import Image
from std_msgs.msg import String, Bool

from node import InferenceNode
from scripts.agilex_model import create_model


class RDTNode(InferenceNode):
    def __init__(self, action_chunk, instruction, ckpt_dir, unnorm_key, hz=20, max_timestep=1000,
                 dataset_name=None, single_arm=True, lang_embed_name=''):
        self.ckpt_dir = ckpt_dir
        self.lang_embed_name = f'outs/{lang_embed_name}.pt'
        self.run_name = f'rdt_{ckpt_dir.split("/")[-1]}'  # for video name
        self.single_arm = single_arm
        super().__init__(hz=hz, max_timestep=max_timestep, dataset_name=dataset_name, single_arm=single_arm)
        self.obs['language_instruction'] = f'{instruction}'
        self.action_chunk = action_chunk
        self.action_counter = 0
        self.unnorm_key = unnorm_key
        self.prompt_sub = self._node.create_subscription(String, '/vla/prompt', self.prompt_callback, 1)
        self.attn = None

    def prompt_callback(self, msg):
        # One-off prompt query against the latest camera frame.
        if self.policy is not None:
            img = self.obs['image']
            pil_image = Image.fromarray(img)
            print(self.policy.inference_prompt(pil_image, msg.data))

    def bringup_model(self):
        with open('configs/base.yaml', 'r') as fp:
            config = yaml.safe_load(fp)
        self.policy = create_model(
            args=config,
            dtype=torch.bfloat16,
            pretrained=self.ckpt_dir,
            # pretrained_text_encoder_name_or_path="google/t5-v1_1-xxl",
            pretrained_vision_encoder_name_or_path="google/siglip-so400m-patch14-384",
            control_frequency=20,
            single_arm=self.single_arm
        )
        self.lang_embeddings = torch.load(self.lang_embed_name)["embeddings"]

    def inference_fn(self):
        # Two timesteps of images are passed to the policy; missing views are None.
        if self.single_arm:
            image_arrs = [
                self.frame_buffer[-2], None, None,
                self.frame_buffer[-1], None, None,
                # self.left_frame_buffer[-1],
            ]
        else:
            image_arrs = [
                self.frame_buffer[-2], self.left_frame_buffer[-2], None,
                self.frame_buffer[-1], self.left_frame_buffer[-1], None,
            ]
        images = [Image.fromarray(arr) if arr is not None else None for arr in image_arrs]

        if self.single_arm:
            proprio = torch.tensor(self.joint_pos_buffer[-1][7:]).unsqueeze(0)
        else:
            proprio = torch.tensor(self.joint_pos_buffer[-1]).unsqueeze(0)

        actions = self.policy.step(
            proprio=proprio,
            images=images,
            text_embeds=self.lang_embeddings
        ).squeeze(0).cpu().numpy()
        return actions

    def inference(self):
        # Query the policy once per chunk, then replay the chunk step by step.
        if self.action_counter == 0:
            with torch.inference_mode():
                start_time = time.time()
                self.actions = self.inference_fn()  # (chunk length, action dim)
                end_time = time.time()
                print(f'{end_time - start_time:.6f} sec')
                # print(self.actions)

        action = self.actions[self.action_counter]
        # action[-1] = action[-1] * 4.0
        if self.single_arm:
            self.joint_action(None, action)
        else:
            self.joint_action(action[:7], action[7:])
        # print(action)
        # print(action[6], action[-1])
        # self.ee_action(None, action)
        # self.target_ee_left += np.array(action[:6])
        # self.target_ee_right += np.array(action[7:-1])
        # action_target_ee_left = np.concatenate([self.target_ee_left, [action[6]]])
        # action_target_ee_right = np.concatenate([self.target_ee_right, [action[-1]]])
        # print(action_target_ee_right)
        # self.ee_action(None, action_target_ee_right)
        # self.ee_action(action_target_ee_left, action_target_ee_right)

        self.action_counter += 1
        if self.action_counter == self.action_chunk:
            self.action_counter = 0

    def done_callback(self, msg):
        if not self.start:
            ## For delta ee control
            if self.data_list is not None:
                # Move to the initial pose recorded in the reference episode.
                root = h5py.File(self.data_list[self.num], 'r')
                skip = 5
                if self.single_arm:
                    self.target_joint_right = root['observation']['joint_pos'][skip, :7]
                    self.joint_action(None, self.target_joint_right)
                else:
                    self.target_joint_left = root['observation']['joint_pos'][skip, :7]
                    self.target_joint_right = root['observation']['joint_pos'][skip, 7:]
                    self.joint_action(self.target_joint_left, self.target_joint_right)
                time.sleep(2)
            else:
                self.target_ee_left = self.obs['left_pose']
                self.target_ee_right = self.obs['right_pose']
            print('Inference & Video Recording Start')
            self.start = True
            msg = Bool()
            msg.data = True
            self.sync_pub.publish(msg)
            self.window.video_start()
        else:
            self.start = False
            msg = Bool()
            msg.data = False
            self.sync_pub.publish(msg)
            self.init_robot()
            self.action_counter = 0
            if self.window.video_recording:
                self.window.video_stop()
            self.initialize()
            print('Next Inference Ready')
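

# Hedged sketch (not part of the original pipeline): bringup_model() expects a
# precomputed language-embedding file at outs/<lang_embed_name>.pt containing a
# dict with an "embeddings" tensor. The RDT codebase produces this file with its
# own encoding script; the helper below only illustrates the assumed on-disk
# format, using the T5 encoder named in the commented-out create_model argument
# ("google/t5-v1_1-xxl"). The exact tokenization/pooling may differ from the
# official script, so treat this as an assumption. It is never called by this node.
def encode_language_instruction(instruction, out_path):
    from transformers import AutoTokenizer, T5EncoderModel  # assumed dependency
    tokenizer = AutoTokenizer.from_pretrained("google/t5-v1_1-xxl")
    encoder = T5EncoderModel.from_pretrained("google/t5-v1_1-xxl")
    tokens = tokenizer(instruction, return_tensors="pt")
    with torch.no_grad():
        embeds = encoder(**tokens).last_hidden_state  # (1, seq_len, hidden_dim)
    torch.save({"embeddings": embeds}, out_path)  # matches torch.load(...)["embeddings"]

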
if __name__ == "__main__":
    ckpt_dir = '/home/univ/workspace/rdt-ckpts/checkpoint-38000'
    action_chunk = 64
    hz = 20
    instruction = 'handover the stuffed doll'
    unnorm_key = 'handover_kirby'
    single_arm = False
    dataset_name = [
        'vla_upright_mug',
        'vla_sweep_screws',
        'vla_pick_ball_place_bin',
        'twinvla_handover_kirby',
        'twinvla_put_bottle',
        'twinvla_detach_ball',
        'twinvla_tear_paper_towel',
    ]
    lang_embed_name = [
        'upright_mug',
        'sweep_screws',
        'pick_ball_place_bin',
        'handover_kirby',
    ]
    num = 3

    node = RDTNode(
        action_chunk=action_chunk,
        instruction=instruction,
        ckpt_dir=ckpt_dir,
        unnorm_key=unnorm_key,
        hz=hz,
        max_timestep=1000,
        dataset_name=dataset_name[num],
        lang_embed_name=lang_embed_name[num],
        single_arm=single_arm
    )

    while True:
        try:
            if node.single_arm:
                img = cv2.cvtColor(node.obs['image'], cv2.COLOR_BGR2RGB)
            else:
                left_img = cv2.cvtColor(node.obs['leftview_image'], cv2.COLOR_BGR2RGB)
                right_img = cv2.cvtColor(node.obs['image'], cv2.COLOR_BGR2RGB)
                img = cv2.hconcat([left_img, right_img])

            if node.start:
                node.window.show(img, overlay_img=None, text=node.obs['language_instruction'])
            else:
                # print(node.attn)
                node.boundary_query()
                node.window.show(img, overlay_img=node.overlay_img,
                                 text=node.obs['language_instruction'], grid=node.grid)
        except KeyboardInterrupt:
            node.ros_close()
            break
        except Exception as e:
            print(f"An error occurred: {e}")
            # node.ros_close()