|
""" |
|
RLDS-based data loader for DROID. |
|
While openpi typically uses LeRobot's data loader, it is not currently scalable enough for larger datasets like DROID. |
|
Thus, we provide a data loader example here that uses the RLDS data format. |
|
The data loader also applies a few DROID-specific data filters / transformations. |
|
""" |
|
|
|
from enum import Enum |
|
from enum import auto |
|
|
|
|
|
class DroidActionSpace(Enum):
    """Which arm-action signal to load from the DROID dataset.

    JOINT_POSITION selects ``action_dict["joint_position"]``;
    JOINT_VELOCITY selects ``action_dict["joint_velocity"]``.
    """

    JOINT_POSITION = auto()
    JOINT_VELOCITY = auto()
|
|
|
|
|
class DroidRldsDataset:
    """Streams DROID frames from an RLDS/TFDS dataset as batched numpy dicts.

    Builds a tf.data pipeline that:
      1. keeps only episodes whose file path marks them as successful,
      2. restructures observations/actions and samples a language instruction,
      3. chunks actions into fixed-size per-frame windows,
      4. drops chunks that are idle (no motion in the first half),
      5. flattens trajectories into frames, decodes images, shuffles, batches.

    Iterating over an instance yields batches as nested dicts of numpy arrays
    with keys "actions", "observation" (image / wrist_image / joint_position /
    gripper_position), and "prompt".
    """

    def __init__(
        self,
        data_dir: str,  # root directory containing the TFDS-built "droid" dataset
        batch_size: int,
        *,  # all remaining arguments are keyword-only
        shuffle: bool = True,
        action_chunk_size: int = 16,  # length of the action window attached to each frame
        # Which arm action signal to load (gripper position is always concatenated on).
        action_space: DroidActionSpace = DroidActionSpace.JOINT_POSITION,
        # NOTE(review): not referenced anywhere in this constructor — dead parameter? confirm.
        max_loaded_steps_per_episode: int = 100,
        # Size of the frame-level shuffle buffer (memory heavy; see note at the shuffle call).
        shuffle_buffer_size: int = 250_000,
        num_parallel_reads: int = -1,  # -1 maps to tf.data AUTOTUNE
        num_parallel_calls: int = -1,  # -1 maps to tf.data AUTOTUNE
    ) -> None:
        # Imported lazily so merely importing this module does not require
        # TensorFlow / dlimp / TFDS to be installed.
        import dlimp as dl
        import tensorflow as tf
        import tensorflow_datasets as tfds

        # Hide all GPUs from TensorFlow: the input pipeline runs on CPU and
        # must not reserve GPU memory away from the training framework.
        tf.config.set_visible_devices([], "GPU")

        builder = tfds.builder("droid", data_dir=data_dir)
        dataset = dl.DLataset.from_rlds(builder, split="train", shuffle=shuffle, num_parallel_reads=num_parallel_reads)

        # Keep only trajectories recorded as successful (episode file path
        # contains "success"); [0] picks the per-trajectory scalar path.
        dataset = dataset.filter(
            lambda traj: tf.strings.regex_full_match(
                traj["traj_metadata"]["episode_metadata"]["file_path"][0], ".*success.*"
            )
        )

        # Repeat forever so downstream consumers see an endless stream of
        # trajectories; epoch boundaries are handled by the caller, not here.
        dataset = dataset.repeat()

        def restructure(traj):
            """Reformat observation and action keys, sample language instruction."""
            # Arm action (joint position or velocity, per `action_space`) with
            # the gripper position concatenated as the final action dimension.
            actions = tf.concat(
                (
                    (
                        traj["action_dict"]["joint_position"]
                        if action_space == DroidActionSpace.JOINT_POSITION
                        else traj["action_dict"]["joint_velocity"]
                    ),
                    traj["action_dict"]["gripper_position"],
                ),
                axis=-1,
            )

            # Randomly pick ONE of the two exterior cameras for the whole
            # trajectory (50/50), so each sample carries a single exterior view.
            exterior_img = tf.cond(
                tf.random.uniform(shape=[]) > 0.5,
                lambda: traj["observation"]["exterior_image_1_left"],
                lambda: traj["observation"]["exterior_image_2_left"],
            )
            wrist_img = traj["observation"]["wrist_image_left"]

            # Sample one of the three language-instruction annotations
            # uniformly at random (per trajectory).
            instruction = tf.random.shuffle(
                [traj["language_instruction"], traj["language_instruction_2"], traj["language_instruction_3"]]
            )[0]

            return {
                "actions": actions,
                "observation": {
                    "image": exterior_img,
                    "wrist_image": wrist_img,
                    "joint_position": traj["observation"]["joint_position"],
                    "gripper_position": traj["observation"]["gripper_position"],
                },
                "prompt": instruction,
            }

        dataset = dataset.traj_map(restructure, num_parallel_calls)

        def chunk_actions(traj):
            """Splits episode into action chunks."""
            traj_len = tf.shape(traj["actions"])[0]

            # [traj_len, chunk] index matrix: row t holds t, t+1, ..., t+chunk-1.
            action_chunk_indices = tf.broadcast_to(
                tf.range(action_chunk_size)[None],
                [traj_len, action_chunk_size],
            ) + tf.broadcast_to(
                tf.range(traj_len)[:, None],
                [traj_len, action_chunk_size],
            )

            # Clamp indices that run past the episode end so the final action
            # is repeated to fill out the trailing chunks.
            action_chunk_indices = tf.minimum(action_chunk_indices, traj_len - 1)

            # actions: [traj_len, dim] -> [traj_len, chunk, dim]
            traj["actions"] = tf.gather(traj["actions"], action_chunk_indices)
            return traj

        dataset = dataset.traj_map(chunk_actions, num_parallel_calls)

        def filter_idle(traj):
            """Filter out chunks with idle actions.
            --> we filter if at least first half of chunk does not move.
            """
            if action_space == DroidActionSpace.JOINT_POSITION:
                # Position targets: motion means any element of the first half
                # of the chunk differs from the chunk's FIRST action (not zero).
                return tf.reduce_any(tf.abs(traj["actions"][: action_chunk_size // 2] - traj["actions"][:1]) > 1e-3)
            # Velocity targets: motion means any nonzero magnitude above threshold.
            return tf.reduce_any(tf.abs(traj["actions"][: action_chunk_size // 2]) > 1e-3)

        dataset = dataset.filter(filter_idle)

        # Trajectory-level dataset -> frame-level dataset (one element per step).
        dataset = dataset.flatten(num_parallel_calls=num_parallel_calls)

        def decode_images(traj):
            # Images arrive as encoded bytes (presumably JPEG/PNG — confirm
            # against the TFDS feature spec); decode to uint8 tensors.
            traj["observation"]["image"] = tf.io.decode_image(
                traj["observation"]["image"], expand_animations=False, dtype=tf.uint8
            )
            traj["observation"]["wrist_image"] = tf.io.decode_image(
                traj["observation"]["wrist_image"], expand_animations=False, dtype=tf.uint8
            )
            return traj

        dataset = dataset.frame_map(decode_images, num_parallel_calls)

        # NOTE(review): shuffling happens AFTER image decoding, so the shuffle
        # buffer holds decoded frames — with the default 250k buffer this needs
        # substantial RAM. Confirm this ordering is intentional.
        dataset = dataset.shuffle(shuffle_buffer_size)
        dataset = dataset.batch(batch_size)

        # dlimp API; presumably caps the tf.data RAM budget (units unclear
        # from here — verify against dlimp docs).
        dataset = dataset.with_ram_budget(1)

        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        # Yields batches as nested dicts of numpy arrays (tf tensors converted
        # by as_numpy_iterator). The stream is infinite due to repeat() above.
        yield from self.dataset.as_numpy_iterator()

    def __len__(self) -> int:
        # Hard-coded approximate frame count so callers that need a length
        # (progress bars, epoch math) work without scanning the data.
        # NOTE(review): not derived from the dataset — verify if exactness matters.
        return 20_000_000
|
|