|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
This script demonstrates how to use torchvision's image transformation with LeRobotDataset for data |
|
|
augmentation purposes. The transformations are passed to the dataset as an argument upon creation, and |
|
|
transforms are applied to the observation images before they are returned in the dataset's __getitem__. |
|
|
""" |
|
|
|
|
|
from pathlib import Path |
|
|
|
|
|
from torchvision.transforms import ToPILImage, v2 |
|
|
|
|
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset |
|
|
|
|
|
dataset_repo_id = "lerobot/aloha_static_screw_driver" |
|
|
|
|
|
|
|
|
dataset = LeRobotDataset(dataset_repo_id, episodes=[0]) |
|
|
|
|
|
|
|
|
|
|
|
first_idx = dataset.episode_data_index["from"][0].item() |
|
|
|
|
|
|
|
|
frame = dataset[first_idx][dataset.meta.camera_keys[0]] |
|
|
|
|
|
|
|
|
|
|
|
# Augmentation pipeline, applied in order to each camera image:
#   1. brightness jitter within (0.5, 1.5)
#   2. contrast jitter within (0.5, 1.5)
#   3. hue jitter within (-0.1, 0.1)
#   4. sharpness adjustment with factor 2, applied with probability p=1
augmentation_steps = [
    v2.ColorJitter(brightness=(0.5, 1.5)),
    v2.ColorJitter(contrast=(0.5, 1.5)),
    v2.ColorJitter(hue=(-0.1, 0.1)),
    v2.RandomAdjustSharpness(sharpness_factor=2, p=1),
]
transforms = v2.Compose(augmentation_steps)

# Same episode as the reference dataset, but with `transforms` handed to the dataset
# so that the images it returns are augmented.
transformed_dataset = LeRobotDataset(
    dataset_repo_id,
    episodes=[0],
    image_transforms=transforms,
)

# Augmented image from the first camera at the same frame index as `frame`.
transformed_camera = transformed_dataset.meta.camera_keys[0]
transformed_frame = transformed_dataset[first_idx][transformed_camera]
|
|
|
|
|
|
|
|
output_dir = Path("outputs/image_transforms") |
|
|
output_dir.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
|
|
|
to_pil = ToPILImage() |
|
|
to_pil(frame).save(output_dir / "original_frame.png", quality=100) |
|
|
print(f"Original frame saved to {output_dir / 'original_frame.png'}.") |
|
|
|
|
|
|
|
|
to_pil(transformed_frame).save(output_dir / "transformed_frame.png", quality=100) |
|
|
print(f"Transformed frame saved to {output_dir / 'transformed_frame.png'}.") |
|
|
|