""" RLDS-based data loader for DROID. While openpi typically uses LeRobot's data loader, it is not currently scalable enough for larger datasets like DROID. Thus, we provide a data loader example here that uses the RLDS data format. The data loader also applies a few DROID-specific data filters / transformations. """ from enum import Enum from enum import auto class DroidActionSpace(Enum): """Action space for DROID dataset.""" JOINT_POSITION = auto() JOINT_VELOCITY = auto() class DroidRldsDataset: def __init__( self, data_dir: str, batch_size: int, *, # Force keyword-only arguments shuffle: bool = True, action_chunk_size: int = 16, # We default to joint position actions, since they allow policy evaluation in simulation. action_space: DroidActionSpace = DroidActionSpace.JOINT_POSITION, max_loaded_steps_per_episode: int = 100, # Reduce this if you are running out of memory, but careful -- below ~100k shuffling is not sufficiently random. shuffle_buffer_size: int = 250_000, num_parallel_reads: int = -1, # -1 == tf.data.AUTOTUNE -- hack to not import tf at top level num_parallel_calls: int = -1, # -1 == tf.data.AUTOTUNE -- hack to not import tf at top level ): # Import tensorflow here to not make it mandatory in case RLDS data loader is not used. import dlimp as dl import tensorflow as tf import tensorflow_datasets as tfds # Configure Tensorflow with *no GPU devices* (to prevent clobber with PyTorch / JAX) tf.config.set_visible_devices([], "GPU") builder = tfds.builder("droid", data_dir=data_dir) dataset = dl.DLataset.from_rlds(builder, split="train", shuffle=shuffle, num_parallel_reads=num_parallel_reads) # Filter out any unsuccessful trajectories -- we use the file name to check this dataset = dataset.filter( lambda traj: tf.strings.regex_full_match( traj["traj_metadata"]["episode_metadata"]["file_path"][0], ".*success.*" ) ) # Repeat dataset so we never run out of data. dataset = dataset.repeat() def restructure(traj): """Reformat observation and action keys, sample language instruction.""" # Important: we use joint *position* action space -- easier to simulate! actions = tf.concat( ( ( traj["action_dict"]["joint_position"] if action_space == DroidActionSpace.JOINT_POSITION else traj["action_dict"]["joint_velocity"] ), traj["action_dict"]["gripper_position"], ), axis=-1, ) # Randomly samples one of the two exterior images in DROID during training (we only train with one at a time). # Note: the "left" refers to the left camera in the stereo pair, we only train on the left camera. exterior_img = tf.cond( tf.random.uniform(shape=[]) > 0.5, lambda: traj["observation"]["exterior_image_1_left"], lambda: traj["observation"]["exterior_image_2_left"], ) wrist_img = traj["observation"]["wrist_image_left"] # Randomly sample one of the three language instructions instruction = tf.random.shuffle( [traj["language_instruction"], traj["language_instruction_2"], traj["language_instruction_3"]] )[0] return { "actions": actions, "observation": { "image": exterior_img, "wrist_image": wrist_img, "joint_position": traj["observation"]["joint_position"], "gripper_position": traj["observation"]["gripper_position"], }, "prompt": instruction, } dataset = dataset.traj_map(restructure, num_parallel_calls) def chunk_actions(traj): """Splits episode into action chunks.""" traj_len = tf.shape(traj["actions"])[0] # For each step in the trajectory, construct indices for the next n actions action_chunk_indices = tf.broadcast_to( tf.range(action_chunk_size)[None], [traj_len, action_chunk_size], ) + tf.broadcast_to( tf.range(traj_len)[:, None], [traj_len, action_chunk_size], ) # Cap to length of the sequence --> final chunks will repeat the last action # This makes sense, since we are using absolute joint + gripper position actions action_chunk_indices = tf.minimum(action_chunk_indices, traj_len - 1) # Gather the actions for each chunk traj["actions"] = tf.gather(traj["actions"], action_chunk_indices) return traj dataset = dataset.traj_map(chunk_actions, num_parallel_calls) def filter_idle(traj): """Filter out chunks with idle actions. --> we filter if at least first half of chunk does not move. """ if action_space == DroidActionSpace.JOINT_POSITION: # Compute delta to first position in action chunk return tf.reduce_any(tf.abs(traj["actions"][: action_chunk_size // 2] - traj["actions"][:1]) > 1e-3) return tf.reduce_any(tf.abs(traj["actions"][: action_chunk_size // 2]) > 1e-3) dataset = dataset.filter(filter_idle) # Flatten: map from trajectory dataset to dataset of individual action chunks dataset = dataset.flatten(num_parallel_calls=num_parallel_calls) # Decode images: RLDS saves encoded images, only decode now for efficiency def decode_images(traj): traj["observation"]["image"] = tf.io.decode_image( traj["observation"]["image"], expand_animations=False, dtype=tf.uint8 ) traj["observation"]["wrist_image"] = tf.io.decode_image( traj["observation"]["wrist_image"], expand_animations=False, dtype=tf.uint8 ) return traj dataset = dataset.frame_map(decode_images, num_parallel_calls) # Shuffle, batch dataset = dataset.shuffle(shuffle_buffer_size) dataset = dataset.batch(batch_size) # Note =>> Seems to reduce memory usage without affecting speed? dataset = dataset.with_ram_budget(1) self.dataset = dataset self.batch_size = batch_size self.shuffle = shuffle def __iter__(self): yield from self.dataset.as_numpy_iterator() def __len__(self): # This is the approximate number of samples in DROID after filtering. # Easier to hardcode than to iterate through the dataset and compute it. return 20_000_000