# openpi-oft: src/openpi/training/droid_rlds_dataset.py
# Uploaded by Sichang0621 via huggingface_hub (commit ce5618e, verified).
"""
RLDS-based data loader for DROID.
While openpi typically uses LeRobot's data loader, that loader does not currently scale well enough for larger datasets like DROID.
Thus, we provide an example data loader here that uses the RLDS data format.
The data loader also applies a few DROID-specific data filters / transformations.
"""
from enum import Enum
from enum import auto
class DroidActionSpace(Enum):
    """Joint action representation used as DROID training targets."""

    # Absolute joint positions, read from ``action_dict["joint_position"]``.
    JOINT_POSITION = auto()
    # Joint velocities, read from ``action_dict["joint_velocity"]``.
    JOINT_VELOCITY = auto()
class DroidRldsDataset:
    """Infinitely repeating, batched RLDS data loader for the DROID dataset.

    The ``tf.data`` / ``dlimp`` pipeline built in ``__init__`` does the following:

    1. Read the ``train`` split of the ``droid`` TFDS builder. Keep only trajectories whose
       episode file path matches ``.*success.*``.
    2. Repeat the dataset indefinitely.
    3. Per trajectory, build the actions: joint position *or* joint velocity, with gripper
       position concatenated on the last axis. Also pick one of the two exterior cameras and
       one of the three language instructions at random.
    4. Turn every step into a chunk of the next ``action_chunk_size`` actions.
    5. Flatten trajectories into individual frames, then drop frames whose action chunk is idle.
    6. Decode the encoded exterior and wrist images of the remaining frames.
    7. Shuffle frames through a ``shuffle_buffer_size`` buffer and batch them.

    Iterating over an instance yields batches as NumPy values
    (via ``as_numpy_iterator``).
    """

    def __init__(
        self,
        data_dir: str,
        batch_size: int,
        *,  # Force keyword-only arguments
        shuffle: bool = True,
        action_chunk_size: int = 16,
        # We default to joint position actions, since they allow policy evaluation in simulation.
        action_space: DroidActionSpace = DroidActionSpace.JOINT_POSITION,
        max_loaded_steps_per_episode: int = 100,
        # Reduce this if you are running out of memory, but careful -- below ~100k shuffling is not sufficiently random.
        shuffle_buffer_size: int = 250_000,
        num_parallel_reads: int = -1,  # -1 == tf.data.AUTOTUNE -- hack to not import tf at top level
        num_parallel_calls: int = -1,  # -1 == tf.data.AUTOTUNE -- hack to not import tf at top level
    ):
        """Build the DROID ``tf.data`` pipeline.

        Args:
            data_dir: Directory passed to ``tfds.builder("droid", data_dir=...)``.
            batch_size: Number of frames per yielded batch.
            shuffle: Forwarded to ``DLataset.from_rlds``. Note that the frame-level
                shuffle buffer (``shuffle_buffer_size``) is applied regardless of this flag.
            action_chunk_size: Number of consecutive future actions in each training sample.
            action_space: Which joint action to use (position or velocity). The gripper
                position action is always concatenated after it on the last axis.
            max_loaded_steps_per_episode: Accepted for interface compatibility; not used
                by this pipeline.
            shuffle_buffer_size: Size of the frame-level shuffle buffer.
            num_parallel_reads: Parallel file reads; -1 means ``tf.data.AUTOTUNE``.
            num_parallel_calls: Parallelism of the map / flatten ops; -1 means ``tf.data.AUTOTUNE``.
        """
        # Import tensorflow here to not make it mandatory in case RLDS data loader is not used.
        import dlimp as dl
        import tensorflow as tf
        import tensorflow_datasets as tfds

        # Configure Tensorflow with *no GPU devices* (to prevent clobber with PyTorch / JAX)
        tf.config.set_visible_devices([], "GPU")

        builder = tfds.builder("droid", data_dir=data_dir)
        dataset = dl.DLataset.from_rlds(builder, split="train", shuffle=shuffle, num_parallel_reads=num_parallel_reads)

        # Filter out any unsuccessful trajectories -- we use the file name to check this
        dataset = dataset.filter(
            lambda traj: tf.strings.regex_full_match(
                traj["traj_metadata"]["episode_metadata"]["file_path"][0], ".*success.*"
            )
        )

        # Repeat dataset so we never run out of data.
        dataset = dataset.repeat()

        def restructure(traj):
            """Reformat observation and action keys, and sample one exterior camera and one instruction."""
            # Joint position (the default, easier to simulate) or joint velocity, per `action_space`,
            # with the gripper position action appended on the last axis.
            joint_key = "joint_position" if action_space == DroidActionSpace.JOINT_POSITION else "joint_velocity"
            actions = tf.concat(
                (traj["action_dict"][joint_key], traj["action_dict"]["gripper_position"]),
                axis=-1,
            )
            # Randomly samples one of the two exterior images in DROID during training (we only train with one at a time).
            # The coin is flipped once per call, i.e. once per trajectory, since this runs via traj_map.
            # Note: the "left" refers to the left camera in the stereo pair, we only train on the left camera.
            exterior_img = tf.cond(
                tf.random.uniform(shape=[]) > 0.5,
                lambda: traj["observation"]["exterior_image_1_left"],
                lambda: traj["observation"]["exterior_image_2_left"],
            )
            wrist_img = traj["observation"]["wrist_image_left"]
            # Randomly sample one of the three language instructions
            instruction = tf.random.shuffle(
                [traj["language_instruction"], traj["language_instruction_2"], traj["language_instruction_3"]]
            )[0]

            return {
                "actions": actions,
                "observation": {
                    "image": exterior_img,
                    "wrist_image": wrist_img,
                    "joint_position": traj["observation"]["joint_position"],
                    "gripper_position": traj["observation"]["gripper_position"],
                },
                "prompt": instruction,
            }

        dataset = dataset.traj_map(restructure, num_parallel_calls)

        def chunk_actions(traj):
            """Replace per-step actions with chunks of the next `action_chunk_size` actions."""
            traj_len = tf.shape(traj["actions"])[0]
            # indices[t, k] = t + k: broadcasting [traj_len, 1] + [1, action_chunk_size].
            action_chunk_indices = tf.range(traj_len)[:, None] + tf.range(action_chunk_size)[None, :]
            # Cap to length of the sequence --> final chunks will repeat the last action
            # This makes sense, since we are using absolute joint + gripper position actions
            action_chunk_indices = tf.minimum(action_chunk_indices, traj_len - 1)
            # Gather the actions for each chunk: shape becomes [traj_len, action_chunk_size, ...].
            traj["actions"] = tf.gather(traj["actions"], action_chunk_indices)
            return traj

        dataset = dataset.traj_map(chunk_actions, num_parallel_calls)

        # Flatten: map from trajectory dataset to dataset of individual frames (one action chunk each)
        dataset = dataset.flatten(num_parallel_calls=num_parallel_calls)

        # Filter out idle frames. This MUST run after flattening: before it, traj["actions"] is
        # [traj_len, action_chunk_size, ...], so slicing the first axis would select *timesteps*,
        # and the predicate would keep or drop whole trajectories instead of individual chunks.
        def filter_idle(frame):
            """Return False for frames whose action chunk is idle.

            A chunk counts as idle when none of its first ``action_chunk_size // 2`` actions
            differs by more than 1e-3. In position space the comparison is against the chunk's
            first action; in velocity space it is against zero.

            NOTE(review): with action_chunk_size < 4 (position) or < 2 (velocity), the window
            contains at most the first action, or nothing. Every frame is then dropped, and
            because of ``repeat()`` iteration never yields -- consider validating the chunk size.
            """
            first_half = frame["actions"][: action_chunk_size // 2]
            if action_space == DroidActionSpace.JOINT_POSITION:
                # Compute delta to first position in action chunk
                return tf.reduce_any(tf.abs(first_half - frame["actions"][:1]) > 1e-3)
            return tf.reduce_any(tf.abs(first_half) > 1e-3)

        dataset = dataset.filter(filter_idle)

        # Decode images: RLDS saves encoded images; decoding only the frames that survived
        # the filters above avoids wasted work.
        def decode_images(traj):
            traj["observation"]["image"] = tf.io.decode_image(
                traj["observation"]["image"], expand_animations=False, dtype=tf.uint8
            )
            traj["observation"]["wrist_image"] = tf.io.decode_image(
                traj["observation"]["wrist_image"], expand_animations=False, dtype=tf.uint8
            )
            return traj

        dataset = dataset.frame_map(decode_images, num_parallel_calls)

        # Shuffle, batch
        dataset = dataset.shuffle(shuffle_buffer_size)
        dataset = dataset.batch(batch_size)

        # Note =>> Seems to reduce memory usage without affecting speed?
        dataset = dataset.with_ram_budget(1)

        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        """Yield batches from the (infinite) pipeline as NumPy values."""
        yield from self.dataset.as_numpy_iterator()

    def __len__(self):
        # This is the approximate number of samples in DROID after filtering.
        # Easier to hardcode than to iterate through the dataset and compute it.
        return 20_000_000